Skip to content

Commit

Permalink
Reimplement
Browse files Browse the repository at this point in the history
  • Loading branch information
d1slike committed Apr 24, 2024
1 parent 5364054 commit f20a333
Show file tree
Hide file tree
Showing 13 changed files with 392 additions and 54 deletions.
11 changes: 7 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ import (
"nhooyr.io/websocket"
)

func main() {
func ExampleServer() {
srv := etp.NewServer(etp.WithServerAcceptOptions(&websocket.AcceptOptions{
InsecureSkipVerify: true, //completely ignore CORS checks, enable only for dev purposes
}))
Expand All @@ -59,7 +59,7 @@ func main() {

//callback to handle disconnection
srv.OnDisconnect(func(conn *etp.Conn, err error) {
fmt.Println("disconnected", conn.Id(), err)
fmt.Println("disconnected", conn.Id(), err, etp.IsNormalClose(err))
})

//callback to handle any error during serving
Expand Down Expand Up @@ -100,7 +100,7 @@ func main() {
OnDisconnect(func(conn *etp.Conn, err error) { //basically you have all handlers like a server here
fmt.Println("client disconnected", conn.Id(), err)
})
err := cli.Dial("ws://localhost:8080/ws")
err := cli.Dial(context.Background(), "ws://localhost:8080/ws")
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -132,7 +132,10 @@ func main() {
panic(err)
}

time.Sleep(1 * time.Second)
time.Sleep(15 * time.Second)

//call to disconnect all clients
srv.Shutdown()
}
```

Expand Down
168 changes: 168 additions & 0 deletions acceptance_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
package etp_test

import (
"context"
"net/http/httptest"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/stretchr/testify/require"
"github.com/txix-open/etp/v3"
"github.com/txix-open/etp/v3/msg"
"nhooyr.io/websocket"
)

type CallHandler struct {
onConnectCount atomic.Int32
onDisconnectCount atomic.Int32
onErrorCount atomic.Int32
events map[string][][]byte
lock sync.Locker
}

func NewCallHandler() *CallHandler {
return &CallHandler{
events: make(map[string][][]byte),
lock: &sync.Mutex{},
}
}

func (c *CallHandler) OnConnect(conn *etp.Conn) {
c.onConnectCount.Add(1)
}

func (c *CallHandler) OnDisconnect(conn *etp.Conn, err error) {
c.onDisconnectCount.Add(1)
}

func (c *CallHandler) OnError(conn *etp.Conn, err error) {
c.onErrorCount.Add(1)
}

func (c *CallHandler) Handle(response []byte) etp.Handler {
return etp.HandlerFunc(func(ctx context.Context, conn *etp.Conn, event msg.Event) []byte {
c.lock.Lock()
defer c.lock.Unlock()

c.events[event.Name] = append(c.events[event.Name], event.Data)
return response
})
}

func TestAcceptance(t *testing.T) {
t.Parallel()
require := require.New(t)

srvHandler := NewCallHandler()
configureSrv := func(srv *etp.Server) {
srv.OnConnect(srvHandler.OnConnect).
OnDisconnect(srvHandler.OnDisconnect).
OnError(srvHandler.OnError).
On("simpleEvent", srvHandler.Handle([]byte("simpleEventResponse"))).
On("ackEvent", srvHandler.Handle([]byte("ackEventResponse"))).
On("emitInHandlerEvent", etp.HandlerFunc(func(ctx context.Context, conn *etp.Conn, event msg.Event) []byte {
srvHandler.Handle(event.Data).Handle(ctx, conn, event)
resp, err := conn.EmitWithAck(ctx, "echo", event.Data)
require.NoError(err)
require.EqualValues(event.Data, resp)
return nil
})).
OnUnknownEvent(srvHandler.Handle(nil))
}

cliHandler := NewCallHandler()
configureCli := func(cli *etp.Client) {
cli.OnConnect(cliHandler.OnConnect).
OnDisconnect(cliHandler.OnDisconnect).
OnError(cliHandler.OnError).
On("echo", etp.HandlerFunc(func(ctx context.Context, conn *etp.Conn, event msg.Event) []byte {
return event.Data
}))
}

cli, _, _ := serve(t, configureSrv, configureCli)

err := cli.Emit(context.Background(), "simpleEvent", []byte("simpleEventPayload"))
require.NoError(err)

resp, err := cli.EmitWithAck(context.Background(), "ackEvent", nil)
require.NoError(err)
require.EqualValues([]byte("ackEventResponse"), resp)

resp, err = cli.EmitWithAck(context.Background(), "emitInHandlerEvent", []byte("emitInHandlerEventPayload"))
require.NoError(err)
require.Nil(resp)

err = cli.Emit(context.Background(), "unknownEvent", []byte("unknownEventPayload"))
require.NoError(err)

time.Sleep(1 * time.Second)

err = cli.Close()
require.NoError(err)

time.Sleep(1 * time.Second)

require.EqualValues(1, srvHandler.onConnectCount.Load())
require.EqualValues(1, srvHandler.onDisconnectCount.Load())
require.EqualValues(0, srvHandler.onErrorCount.Load())

require.Len(srvHandler.events["simpleEvent"], 1)
require.EqualValues([]byte("simpleEventPayload"), srvHandler.events["simpleEvent"][0])

require.Len(srvHandler.events["ackEvent"], 1)
require.Nil(srvHandler.events["ackEvent"][0])

require.Len(srvHandler.events["emitInHandlerEvent"], 1)
require.EqualValues([]byte("emitInHandlerEventPayload"), srvHandler.events["emitInHandlerEvent"][0])

require.Len(srvHandler.events["unknownEvent"], 1)
require.EqualValues([]byte("unknownEventPayload"), srvHandler.events["unknownEvent"][0])

require.EqualValues(1, cliHandler.onConnectCount.Load())
require.EqualValues(1, cliHandler.onDisconnectCount.Load())
require.EqualValues(0, cliHandler.onErrorCount.Load())
}

func serve(
t *testing.T,
srvCfg func(srv *etp.Server),
cliCfg func(cli *etp.Client),
) (*etp.Client, *etp.Server, *httptest.Server) {
t.Helper()

srv := etp.NewServer(etp.WithServerAcceptOptions(
&websocket.AcceptOptions{
InsecureSkipVerify: true,
},
))
if srvCfg != nil {
srvCfg(srv)
}
testServer := httptest.NewServer(srv)
t.Cleanup(func() {
testServer.Close()
srv.Shutdown()
})

cli := etp.NewClient(etp.WithClientDialOptions(
&websocket.DialOptions{
HTTPClient: testServer.Client(),
},
))
if cliCfg != nil {
cliCfg(cli)
}
err := cli.Dial(context.Background(), strings.ReplaceAll(testServer.URL, "http", "ws"))
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
_ = cli.Close()
})

return cli, srv, testServer
}
16 changes: 9 additions & 7 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,20 @@ var (
type Client struct {
mux *mux
idGenerator *internal.IdGenerator
clientOpts *ClientOptions
opts *clientOptions
conn *Conn
lock sync.Locker
}

func NewClient(opts ...ClientOption) *Client {
options := DefaultClientOptions()
options := defaultClientOptions()
for _, opt := range opts {
opt(options)
}
return &Client{
mux: newMux(),
idGenerator: internal.NewIdGenerator(),
clientOpts: options,
opts: options,
lock: &sync.Mutex{},
conn: nil,
}
Expand Down Expand Up @@ -61,28 +61,30 @@ func (c *Client) OnUnknownEvent(handler Handler) *Client {
return c
}

func (c *Client) Dial(url string) error {
func (c *Client) Dial(ctx context.Context, url string) error {
c.lock.Lock()
defer c.lock.Unlock()

if c.conn != nil {
return errors.New("already connected")
}

ws, resp, err := websocket.Dial(context.Background(), url, c.clientOpts.dialOptions)
ws, resp, err := websocket.Dial(ctx, url, c.opts.dialOptions)
if err != nil {
return fmt.Errorf("websocket dial: %w", err)
}

ws.SetReadLimit(c.clientOpts.readLimit)
ws.SetReadLimit(c.opts.readLimit)

id := c.idGenerator.Next()
conn := newConn(id, resp.Request, ws)
c.conn = conn

keeper := newKeeper(conn, c.mux)
go func() {
defer conn.Close()
defer func() {
_ = ws.CloseNow()
}()

keeper.Serve(context.Background())

Expand Down
18 changes: 10 additions & 8 deletions client_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,29 @@ import (
"nhooyr.io/websocket"
)

type ClientOption func(*ClientOptions)
type DialOptions = websocket.DialOptions

type ClientOptions struct {
dialOptions *websocket.DialOptions
type ClientOption func(*clientOptions)

type clientOptions struct {
dialOptions *DialOptions
readLimit int64
}

func DefaultClientOptions() *ClientOptions {
return &ClientOptions{
func defaultClientOptions() *clientOptions {
return &clientOptions{
readLimit: defaultReadLimit,
}
}

func WithClientDialOptions(opts *websocket.DialOptions) ClientOption {
return func(o *ClientOptions) {
func WithClientDialOptions(opts *DialOptions) ClientOption {
return func(o *clientOptions) {
o.dialOptions = opts
}
}

func WithClientReadLimit(limit int64) ClientOption {
return func(o *ClientOptions) {
return func(o *clientOptions) {
o.readLimit = limit
}
}
Loading

0 comments on commit f20a333

Please sign in to comment.