Skip to content

Commit

Permalink
conn: add websocket options
Browse files Browse the repository at this point in the history
Refactors conn.WebSocket into its own package and adds websocket dial
options.
  • Loading branch information
andydunstall committed May 31, 2024
1 parent ac91a54 commit 26d9f06
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 100 deletions.
5 changes: 4 additions & 1 deletion agent/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/andydunstall/piko/agent/config"
"github.com/andydunstall/piko/pkg/backoff"
"github.com/andydunstall/piko/pkg/conn"
"github.com/andydunstall/piko/pkg/conn/websocket"
"github.com/andydunstall/piko/pkg/log"
"github.com/andydunstall/piko/pkg/rpc"
"go.uber.org/zap"
Expand Down Expand Up @@ -109,7 +110,9 @@ func (e *Endpoint) connect(ctx context.Context) (rpc.Stream, error) {
e.conf.Server.ReconnectMaxBackoff,
)
for {
c, err := conn.DialWebsocket(ctx, e.serverURL(), e.conf.Auth.APIKey)
c, err := websocket.Dial(
ctx, e.serverURL(), websocket.WithToken(e.conf.Auth.APIKey),
)
if err == nil {
return rpc.NewStream(c, e.rpcServer.Handler(), e.logger), nil
}
Expand Down
4 changes: 4 additions & 0 deletions pkg/conn/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ type RetryableError struct {
err error
}

func NewRetryableError(err error) *RetryableError {
return &RetryableError{err}
}

func (e *RetryableError) Unwrap() error {
return e.err
}
Expand Down
91 changes: 0 additions & 91 deletions pkg/conn/websocket.go

This file was deleted.

113 changes: 113 additions & 0 deletions pkg/conn/websocket/websocket.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package websocket

import (
"context"
"fmt"
"io"
"net/http"

"github.com/andydunstall/piko/pkg/conn"
"github.com/gorilla/websocket"
)

// retryableStatusCodes contains a set of HTTP status codes that should be
// retried.
var retryableStatusCodes = map[int]struct{}{
http.StatusRequestTimeout: {},
http.StatusTooManyRequests: {},
http.StatusInternalServerError: {},
http.StatusBadGateway: {},
http.StatusServiceUnavailable: {},
http.StatusGatewayTimeout: {},
}

type Options struct {
token string
}

type Option interface {
apply(*Options)
}

type tokenOption string

func (o tokenOption) apply(opts *Options) {
opts.token = string(o)
}

func WithToken(token string) Option {
return tokenOption(token)
}

type Conn struct {
wsConn *websocket.Conn
}

func NewConn(wsConn *websocket.Conn) *Conn {
return &Conn{
wsConn: wsConn,
}
}

func Dial(ctx context.Context, url string, opts ...Option) (*Conn, error) {
options := Options{}
for _, o := range opts {
o.apply(&options)
}

header := make(http.Header)
if options.token != "" {
header.Set("Authorization", "Bearer "+options.token)
}
wsConn, resp, err := websocket.DefaultDialer.DialContext(
ctx, url, header,
)
if err != nil {
if resp != nil {
if _, ok := retryableStatusCodes[resp.StatusCode]; ok {
return nil, conn.NewRetryableError(err)
}
return nil, fmt.Errorf("%d: %w", resp.StatusCode, err)
}
return nil, conn.NewRetryableError(err)
}
return NewConn(wsConn), nil
}

func (c *Conn) ReadMessage() ([]byte, error) {
mt, message, err := c.wsConn.ReadMessage()
if err != nil {
return nil, err
}
if mt != websocket.BinaryMessage {
return nil, fmt.Errorf("unexpected websocket message type: %d", mt)
}
return message, nil
}

func (c *Conn) NextReader() (io.Reader, error) {
mt, r, err := c.wsConn.NextReader()
if err != nil {
return nil, err
}
if mt != websocket.BinaryMessage {
return nil, fmt.Errorf("unexpected websocket message type: %d", mt)
}
return r, nil
}

func (c *Conn) WriteMessage(b []byte) error {
return c.wsConn.WriteMessage(websocket.BinaryMessage, b)
}

func (c *Conn) NextWriter() (io.WriteCloser, error) {
return c.wsConn.NextWriter(websocket.BinaryMessage)
}

func (c *Conn) Addr() string {
return c.wsConn.RemoteAddr().String()
}

func (c *Conn) Close() error {
return c.wsConn.Close()
}
4 changes: 2 additions & 2 deletions server/server/upstream/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"net/http"
"time"

"github.com/andydunstall/piko/pkg/conn"
pikowebsocket "github.com/andydunstall/piko/pkg/conn/websocket"
"github.com/andydunstall/piko/pkg/log"
"github.com/andydunstall/piko/pkg/rpc"
"github.com/andydunstall/piko/server/auth"
Expand Down Expand Up @@ -131,7 +131,7 @@ func (s *Server) listenerRoute(c *gin.Context) {
return
}
stream := rpc.NewStream(
conn.NewWebsocketConn(wsConn),
pikowebsocket.NewConn(wsConn),
s.rpcServer.Handler(),
s.logger,
)
Expand Down
12 changes: 6 additions & 6 deletions server/server/upstream/server_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"testing"
"time"

"github.com/andydunstall/piko/pkg/conn"
"github.com/andydunstall/piko/pkg/conn/websocket"
"github.com/andydunstall/piko/pkg/log"
"github.com/andydunstall/piko/pkg/rpc"
"github.com/andydunstall/piko/server/auth"
Expand Down Expand Up @@ -70,7 +70,7 @@ func TestServer_AddConn(t *testing.T) {
upstreamLn.Addr().String(),
)
rpcServer := newRPCServer()
conn, err := conn.DialWebsocket(context.TODO(), url, "")
conn, err := websocket.Dial(context.TODO(), url)
require.NoError(t, err)

// Add client stream and ensure upstream added to proxy.
Expand Down Expand Up @@ -113,7 +113,7 @@ func TestServer_AddConn(t *testing.T) {
upstreamLn.Addr().String(),
)
rpcServer := newRPCServer()
conn, err := conn.DialWebsocket(context.TODO(), url, "123")
conn, err := websocket.Dial(context.TODO(), url, websocket.WithToken("123"))
require.NoError(t, err)

// Add client stream and ensure upstream added to proxy.
Expand Down Expand Up @@ -157,7 +157,7 @@ func TestServer_AddConn(t *testing.T) {
upstreamLn.Addr().String(),
)
rpcServer := newRPCServer()
conn, err := conn.DialWebsocket(context.TODO(), url, "123")
conn, err := websocket.Dial(context.TODO(), url, websocket.WithToken("123"))
require.NoError(t, err)

// Add client stream and ensure upstream added to proxy.
Expand Down Expand Up @@ -199,7 +199,7 @@ func TestServer_AddConn(t *testing.T) {
"ws://%s/piko/v1/listener/my-endpoint",
upstreamLn.Addr().String(),
)
_, err = conn.DialWebsocket(context.TODO(), url, "123")
_, err = websocket.Dial(context.TODO(), url, websocket.WithToken("123"))
require.Error(t, err)
})

Expand Down Expand Up @@ -229,7 +229,7 @@ func TestServer_AddConn(t *testing.T) {
"ws://%s/piko/v1/listener/my-endpoint",
upstreamLn.Addr().String(),
)
_, err = conn.DialWebsocket(context.TODO(), url, "123")
_, err = websocket.Dial(context.TODO(), url, websocket.WithToken("123"))
require.Error(t, err)
})

Expand Down

0 comments on commit 26d9f06

Please sign in to comment.