Skip to content

Commit

Permalink
Split into 'roundtripWebSocketClient'
Browse files Browse the repository at this point in the history
  • Loading branch information
HaraldNordgren committed Sep 28, 2024
1 parent f67ebe1 commit bf9a285
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 53 deletions.
2 changes: 1 addition & 1 deletion internal/integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func TestSubscription(t *testing.T) {
ctx := context.Background()
server := server.RunServer()
defer server.Close()
wsClient := newCountRoundtripWebSocketClient(t, server.URL)
wsClient := newCountWebSocketClient(t, server.URL)

errChan, err := wsClient.Start(ctx)
require.NoError(t, err)
Expand Down
52 changes: 0 additions & 52 deletions internal/integration/roundtrip.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,9 @@ import (
"fmt"
"io"
"net/http"
"strings"
"testing"

"github.com/Khan/genqlient/graphql"
"github.com/gorilla/websocket"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -128,53 +126,3 @@ func newRoundtripGetClient(t *testing.T, endpoint string) graphql.Client {
t: t,
}
}

type roundtripWebSocketClient[T any] struct {
wrapped graphql.WebSocketClient[T]
t *testing.T
}

func (c *roundtripWebSocketClient[_]) Start(ctx context.Context) (errChan chan error, err error) {
return c.wrapped.Start(ctx)
}

func (c *roundtripWebSocketClient[_]) Close() error {
return c.wrapped.Close()
}

func (c *roundtripWebSocketClient[T]) Subscribe(req *graphql.Request, dataChan chan graphql.WsResponse[T], forwardDataFunc graphql.ForwardDataFunction[T]) (string, error) {
return c.wrapped.Subscribe(req, dataChan, forwardDataFunc)
}

func (c *roundtripWebSocketClient[_]) Unsubscribe(subscriptionID string) error {
return c.wrapped.Unsubscribe(subscriptionID)
}

type MyDialer struct {
*websocket.Dialer
}

func (md *MyDialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (graphql.WSConn, error) {
conn, resp, err := md.Dialer.DialContext(ctx, urlStr, requestHeader)
resp.Body.Close()
return graphql.WSConn(conn), err
}

func wsAdress(endpoint string) string {
if !strings.HasPrefix(endpoint, "ws") {
_, address, _ := strings.Cut(endpoint, "://")
endpoint = "ws://" + address
}
return endpoint
}

func newCountRoundtripWebSocketClient(t *testing.T, endpoint string) graphql.WebSocketClient[countResponse] {
return &roundtripWebSocketClient[countResponse]{
wrapped: countClientUsingWebSocket(
wsAdress(endpoint),
&MyDialer{Dialer: websocket.DefaultDialer},
nil,
),
t: t,
}
}
61 changes: 61 additions & 0 deletions internal/integration/websocket.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package integration

import (
"context"
"net/http"
"strings"
"testing"

"github.com/Khan/genqlient/graphql"
"github.com/gorilla/websocket"
)

type webSocketClient[T any] struct {
wrapped graphql.WebSocketClient[T]
t *testing.T
}

func (c *webSocketClient[_]) Start(ctx context.Context) (errChan chan error, err error) {
return c.wrapped.Start(ctx)
}

func (c *webSocketClient[_]) Close() error {
return c.wrapped.Close()
}

func (c *webSocketClient[T]) Subscribe(req *graphql.Request, dataChan chan graphql.WsResponse[T], forwardDataFunc graphql.ForwardDataFunction[T]) (string, error) {
return c.wrapped.Subscribe(req, dataChan, forwardDataFunc)
}

func (c *webSocketClient[_]) Unsubscribe(subscriptionID string) error {
return c.wrapped.Unsubscribe(subscriptionID)
}

type MyDialer struct {
*websocket.Dialer
}

func (md *MyDialer) DialContext(ctx context.Context, urlStr string, requestHeader http.Header) (graphql.WSConn, error) {
conn, resp, err := md.Dialer.DialContext(ctx, urlStr, requestHeader)
resp.Body.Close()
return graphql.WSConn(conn), err
}

func wsAdress(endpoint string) string {
if !strings.HasPrefix(endpoint, "ws") {
_, address, _ := strings.Cut(endpoint, "://")
endpoint = "ws://" + address
}
return endpoint
}

func newCountWebSocketClient(t *testing.T, endpoint string) graphql.WebSocketClient[countResponse] {
return &webSocketClient[countResponse]{
wrapped: countClientUsingWebSocket(
wsAdress(endpoint),
&MyDialer{Dialer: websocket.DefaultDialer},
nil,
),
t: t,
}
}

0 comments on commit bf9a285

Please sign in to comment.