Skip to content
This repository has been archived by the owner on Mar 27, 2024. It is now read-only.

feat: DIDComm Inbound Transport - Support for WebSocket #828

Merged
merged 1 commit into from
Nov 19, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cmd/aries-agent-rest/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY=
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/mux v1.7.3 h1:gnP5JzjVOuiZD07fKKToCAOjS0yOpj/qPETTXCCS6hw=
github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ require (
github.com/golang/mock v1.3.1
github.com/google/uuid v1.1.1
github.com/gorilla/mux v1.7.3
github.com/gorilla/websocket v1.4.1
github.com/kr/pretty v0.1.0 // indirect
github.com/multiformats/go-multibase v0.0.1
github.com/multiformats/go-multihash v0.0.8
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY=
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/mux v1.7.3 h1:gnP5JzjVOuiZD07fKKToCAOjS0yOpj/qPETTXCCS6hw=
github.com/gorilla/mux v1.7.3/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM=
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
Expand Down
2 changes: 1 addition & 1 deletion pkg/didcomm/transport/http/inbound.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import (
"github.com/hyperledger/aries-framework-go/pkg/didcomm/transport"
)

var logger = log.New("aries-framework/transport")
var logger = log.New("aries-framework/http")

// provider contains dependencies for the HTTP Handler creation and is typically created by using aries.Context()
type provider interface {
Expand Down
139 changes: 139 additions & 0 deletions pkg/didcomm/transport/ws/inbound.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
/*
Copyright SecureKey Technologies Inc. All Rights Reserved.

SPDX-License-Identifier: Apache-2.0
*/

package ws

import (
"context"
"errors"
"fmt"
"net/http"

"github.com/gorilla/websocket"

"github.com/hyperledger/aries-framework-go/pkg/common/log"
"github.com/hyperledger/aries-framework-go/pkg/didcomm/transport"
)

var logger = log.New("aries-framework/ws")

const processFailureErrMsg = "failed to process the message"

// Inbound http(ws) type.
type Inbound struct {
externalAddr string
server *http.Server
}

// NewInbound creates a new WebSocket inbound transport instance.
func NewInbound(internalAddr, externalAddr string) (*Inbound, error) {
if internalAddr == "" {
return nil, errors.New("websocket address is mandatory")
}

if externalAddr == "" {
return &Inbound{externalAddr: internalAddr, server: &http.Server{Addr: internalAddr}}, nil
}

return &Inbound{externalAddr: externalAddr, server: &http.Server{Addr: internalAddr}}, nil
}

// Start the http(ws) server.
func (i *Inbound) Start(prov transport.InboundProvider) error {
handler, err := newInboundHandler(prov)
if err != nil {
return fmt.Errorf("websocket server start failed: %w", err)
}

i.server.Handler = handler

go func() {
if err := i.server.ListenAndServe(); err != http.ErrServerClosed {
logger.Fatalf("websocket server start with address [%s] failed, cause: %s", i.server.Addr, err)
}
}()

return nil
}

// Stop the http(ws) server.
func (i *Inbound) Stop() error {
if err := i.server.Shutdown(context.Background()); err != nil {
return fmt.Errorf("websocket server shutdown failed: %w", err)
}

return nil
}

// Endpoint provides the http(ws) connection details.
func (i *Inbound) Endpoint() string {
return i.externalAddr
}

func newInboundHandler(prov transport.InboundProvider) (http.Handler, error) {
if prov == nil || prov.InboundMessageHandler() == nil {
logger.Errorf("Error creating a new inbound handler: message handler function is nil")
return nil, errors.New("creation of inbound handler failed")
}

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
processRequest(w, r, prov)
}), nil
}

func processRequest(w http.ResponseWriter, r *http.Request, prov transport.InboundProvider) {
upgrader := websocket.Upgrader{}

c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
logger.Errorf("failed to upgrade the connection : %v", err)
return
}

defer func() {
err := c.Close()
if err != nil {
logger.Errorf("failed to close connection: %v", err)
}
}()

for {
_, message, err := c.ReadMessage()
if err != nil {
logger.Errorf("Error reading request message: %v", err)

break
}

unpackMsg, err := prov.Packager().UnpackMessage(message)
if err != nil {
logger.Errorf("failed to unpack msg: %v", err)

err = c.WriteMessage(websocket.TextMessage, []byte(processFailureErrMsg))
if err != nil {
logger.Errorf("error writing the message: %v", err)
}

continue
}

messageHandler := prov.InboundMessageHandler()

resp := ""

err = messageHandler(unpackMsg.Message)
if err != nil {
logger.Errorf("incoming msg processing failed: %v", err)

resp = processFailureErrMsg
}

err = c.WriteMessage(websocket.TextMessage, []byte(resp))
if err != nil {
logger.Errorf("error writing the message: %v", err)
}
}
}
183 changes: 183 additions & 0 deletions pkg/didcomm/transport/ws/inbound_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
/*
Copyright SecureKey Technologies Inc. All Rights Reserved.

SPDX-License-Identifier: Apache-2.0
*/

package ws

import (
"errors"
"net/http"
"net/url"
"strconv"
"testing"
"time"

"github.com/hyperledger/aries-framework-go/pkg/internal/test/transportutil"

"github.com/gorilla/websocket"
"github.com/stretchr/testify/require"

commontransport "github.com/hyperledger/aries-framework-go/pkg/didcomm/common/transport"
"github.com/hyperledger/aries-framework-go/pkg/didcomm/transport"
mockpackager "github.com/hyperledger/aries-framework-go/pkg/internal/mock/didcomm/packager"
)

type mockProvider struct {
packagerValue commontransport.Packager
}

func (p *mockProvider) InboundMessageHandler() transport.InboundMessageHandler {
return func(message []byte) error {
logger.Infof("message received is %s", string(message))
if string(message) == "invalid-data" {
return errors.New("error")
}
return nil
}
}

func (p *mockProvider) Packager() commontransport.Packager {
return p.packagerValue
}

func TestInboundTransport(t *testing.T) {
t.Run("test inbound transport - with host/port", func(t *testing.T) {
port := ":" + strconv.Itoa(transportutil.GetRandomPort(5))
externalAddr := "http://example.com" + port
inbound, err := NewInbound("localhost"+port, externalAddr)
require.NoError(t, err)
require.Equal(t, externalAddr, inbound.Endpoint())
})

t.Run("test inbound transport - with host/port, no external address", func(t *testing.T) {
internalAddr := "example.com" + ":" + strconv.Itoa(transportutil.GetRandomPort(5))
inbound, err := NewInbound(internalAddr, "")
require.NoError(t, err)
require.Equal(t, internalAddr, inbound.Endpoint())
})

t.Run("test inbound transport - without host/port", func(t *testing.T) {
inbound, err := NewInbound(":"+strconv.Itoa(transportutil.GetRandomPort(5)), "")
require.NoError(t, err)
require.NotEmpty(t, inbound)
mockPackager := &mockpackager.Packager{UnpackValue: &commontransport.Envelope{Message: []byte("data")}}
err = inbound.Start(&mockProvider{packagerValue: mockPackager})
require.NoError(t, err)

err = inbound.Stop()
require.NoError(t, err)
})

t.Run("test inbound transport - nil context", func(t *testing.T) {
inbound, err := NewInbound(":"+strconv.Itoa(transportutil.GetRandomPort(5)), "")
require.NoError(t, err)
require.NotEmpty(t, inbound)

err = inbound.Start(nil)
require.Error(t, err)
})

t.Run("test inbound transport - invalid port number", func(t *testing.T) {
_, err := NewInbound("", "")
require.Error(t, err)
require.Contains(t, err.Error(), "websocket address is mandatory")
})
}

func TestInboundDataProcessing(t *testing.T) {
t.Run("test inbound transport - multiple invocation with same client", func(t *testing.T) {
port := ":" + strconv.Itoa(transportutil.GetRandomPort(5))

// initiate inbound with port
inbound, err := NewInbound(port, "")
require.NoError(t, err)
require.NotEmpty(t, inbound)

// start server
mockPackager := &mockpackager.Packager{UnpackValue: &commontransport.Envelope{Message: []byte("valid-data")}}
err = inbound.Start(&mockProvider{packagerValue: mockPackager})
require.NoError(t, err)

// create ws client
client, cleanup := websocketClient(t, port)
defer cleanup()

for i := 1; i <= 5; i++ {
err = client.WriteMessage(websocket.TextMessage, []byte("random"))
require.NoError(t, err)

messageType, val, err := client.ReadMessage()
require.NoError(t, err)
require.Equal(t, messageType, websocket.TextMessage)
require.Equal(t, "", string(val))
}
})

t.Run("test inbound transport - unpacking error", func(t *testing.T) {
port := ":" + strconv.Itoa(transportutil.GetRandomPort(5))

// initiate inbound with port
inbound, err := NewInbound(port, "")
require.NoError(t, err)
require.NotEmpty(t, inbound)

// start server
mockPackager := &mockpackager.Packager{UnpackErr: errors.New("error unpacking")}
err = inbound.Start(&mockProvider{packagerValue: mockPackager})
require.NoError(t, err)

// create ws client
client, cleanup := websocketClient(t, port)
defer cleanup()

err = client.WriteMessage(websocket.TextMessage, []byte(""))
require.NoError(t, err)

messageType, val, err := client.ReadMessage()
require.NoError(t, err)
require.Equal(t, messageType, websocket.TextMessage)
require.Equal(t, processFailureErrMsg, string(val))
})

t.Run("test inbound transport - message handler error", func(t *testing.T) {
port := ":" + strconv.Itoa(transportutil.GetRandomPort(5))

// initiate inbound with port
inbound, err := NewInbound(port, "")
require.NoError(t, err)
require.NotEmpty(t, inbound)

// start server
mockPackager := &mockpackager.Packager{UnpackValue: &commontransport.Envelope{Message: []byte("invalid-data")}}
err = inbound.Start(&mockProvider{packagerValue: mockPackager})
require.NoError(t, err)

// create ws client
client, cleanup := websocketClient(t, port)
defer cleanup()

err = client.WriteMessage(websocket.TextMessage, []byte(""))
require.NoError(t, err)

messageType, val, err := client.ReadMessage()
require.NoError(t, err)
require.Equal(t, messageType, websocket.TextMessage)
require.Equal(t, processFailureErrMsg, string(val))
})
}

func websocketClient(t *testing.T, port string) (*websocket.Conn, func()) {
require.NoError(t, transportutil.VerifyListener("localhost"+port, time.Second))

u := url.URL{Scheme: "ws", Host: "localhost" + port, Path: ""}
c, resp, err := websocket.DefaultDialer.Dial(u.String(), nil)
require.NoError(t, err)
require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
require.NoError(t, resp.Body.Close())

return c, func() {
require.NoError(t, c.Close())
}
}
Loading