Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable compression by default and switch to stdlib compress #240

Merged
merged 5 commits into from
May 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ go get nhooyr.io/websocket
- Minimal and idiomatic API
- First class [context.Context](https://blog.golang.org/context) support
- Fully passes the WebSocket [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite)
- [Single dependency](https://pkg.go.dev/nhooyr.io/websocket?tab=imports)
- [Zero dependencies](https://pkg.go.dev/nhooyr.io/websocket?tab=imports)
- JSON and protobuf helpers in the [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) and [wspb](https://pkg.go.dev/nhooyr.io/websocket/wspb) subpackages
- Zero alloc reads and writes
- Concurrent writes
Expand Down Expand Up @@ -112,7 +112,6 @@ Advantages of nhooyr.io/websocket:
- Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/).
- Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support
- Gorilla only supports no context takeover mode
- We use [klauspost/compress](https://github.com/klauspost/compress) for much lower memory usage ([gorilla/websocket#203](https://github.com/gorilla/websocket/issues/203))
- [CloseRead](https://pkg.go.dev/nhooyr.io/websocket#Conn.CloseRead) helper ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492))
- Actively maintained ([gorilla/websocket#370](https://github.com/gorilla/websocket/issues/370))

Expand Down
2 changes: 1 addition & 1 deletion accept.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ type AcceptOptions struct {
OriginPatterns []string

// CompressionMode controls the compression mode.
// Defaults to CompressionNoContextTakeover.
// Defaults to CompressionDisabled.
//
// See docs on CompressionMode for details.
CompressionMode CompressionMode
Expand Down
20 changes: 0 additions & 20 deletions accept_js.go

This file was deleted.

4 changes: 3 additions & 1 deletion accept_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ func TestAccept(t *testing.T) {
r.Header.Set("Sec-WebSocket-Key", "meow123")
r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; harharhar")

_, err := Accept(w, r, nil)
_, err := Accept(w, r, &AcceptOptions{
CompressionMode: CompressionContextTakeover,
})
assert.Contains(t, err, `unsupported permessage-deflate parameter`)
})

Expand Down
14 changes: 11 additions & 3 deletions autobahn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ var excludedAutobahnCases = []string{

// We skip the tests related to requestMaxWindowBits as that is unimplemented due
// to limitations in compress/flate. See https://github.com/golang/go/issues/3155
// Same with klauspost/compress which doesn't allow adjusting the sliding window size.
"13.3.*", "13.4.*", "13.5.*", "13.6.*",
}

Expand All @@ -37,10 +36,17 @@ var autobahnCases = []string{"*"}
func TestAutobahn(t *testing.T) {
t.Parallel()

if os.Getenv("AUTOBAHN_TEST") == "" {
if os.Getenv("AUTOBAHN") == "" {
t.SkipNow()
}

if os.Getenv("AUTOBAHN") == "fast" {
// These are the slow tests.
excludedAutobahnCases = append(excludedAutobahnCases,
"9.*", "13.*", "12.*",
)
}

ctx, cancel := context.WithTimeout(context.Background(), time.Minute*15)
defer cancel()

Expand All @@ -61,7 +67,9 @@ func TestAutobahn(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
defer cancel()

c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/runCase?case=%v&agent=main", i), nil)
c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/runCase?case=%v&agent=main", i), &websocket.DialOptions{
CompressionMode: websocket.CompressionContextTakeover,
})
assert.Success(t, err)
err = wstest.EchoLoop(ctx, c)
t.Logf("echoLoop: %v", err)
Expand Down
205 changes: 205 additions & 0 deletions close.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
// +build !js

package websocket

import (
"context"
"encoding/binary"
"errors"
"fmt"
"log"
"time"

"nhooyr.io/websocket/internal/errd"
)

// StatusCode represents a WebSocket status code.
Expand Down Expand Up @@ -74,3 +82,200 @@ func CloseStatus(err error) StatusCode {
}
return -1
}

// Close performs the WebSocket close handshake with the given status code and reason.
//
// It will write a WebSocket close frame with a timeout of 5s and then wait 5s for
// the peer to send a close frame.
// All data messages received from the peer during the close handshake will be discarded.
//
// The connection can only be closed once. Additional calls to Close
// are no-ops.
//
// The maximum length of reason must be 125 bytes. Avoid
// sending a dynamic reason.
//
// Close will unblock all goroutines interacting with the connection once
// complete.
func (c *Conn) Close(code StatusCode, reason string) error {
return c.closeHandshake(code, reason)
}

func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) {
defer errd.Wrap(&err, "failed to close WebSocket")

writeErr := c.writeClose(code, reason)
closeHandshakeErr := c.waitCloseHandshake()

if writeErr != nil {
return writeErr
}

if CloseStatus(closeHandshakeErr) == -1 {
return closeHandshakeErr
}

return nil
}

var errAlreadyWroteClose = errors.New("already wrote close")

func (c *Conn) writeClose(code StatusCode, reason string) error {
c.closeMu.Lock()
wroteClose := c.wroteClose
c.wroteClose = true
c.closeMu.Unlock()
if wroteClose {
return errAlreadyWroteClose
}

ce := CloseError{
Code: code,
Reason: reason,
}

var p []byte
var marshalErr error
if ce.Code != StatusNoStatusRcvd {
p, marshalErr = ce.bytes()
if marshalErr != nil {
log.Printf("websocket: %v", marshalErr)
}
}

writeErr := c.writeControl(context.Background(), opClose, p)
if CloseStatus(writeErr) != -1 {
// Not a real error if it's due to a close frame being received.
writeErr = nil
}

// We do this after in case there was an error writing the close frame.
c.setCloseErr(fmt.Errorf("sent close frame: %w", ce))

if marshalErr != nil {
return marshalErr
}
return writeErr
}

func (c *Conn) waitCloseHandshake() error {
defer c.close(nil)

ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()

err := c.readMu.lock(ctx)
if err != nil {
return err
}
defer c.readMu.unlock()

if c.readCloseFrameErr != nil {
return c.readCloseFrameErr
}

for {
h, err := c.readLoop(ctx)
if err != nil {
return err
}

for i := int64(0); i < h.payloadLength; i++ {
_, err := c.br.ReadByte()
if err != nil {
return err
}
}
}
}

func parseClosePayload(p []byte) (CloseError, error) {
if len(p) == 0 {
return CloseError{
Code: StatusNoStatusRcvd,
}, nil
}

if len(p) < 2 {
return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p)
}

ce := CloseError{
Code: StatusCode(binary.BigEndian.Uint16(p)),
Reason: string(p[2:]),
}

if !validWireCloseCode(ce.Code) {
return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code)
}

return ce, nil
}

// See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
// and https://tools.ietf.org/html/rfc6455#section-7.4.1
func validWireCloseCode(code StatusCode) bool {
switch code {
case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake:
return false
}

if code >= StatusNormalClosure && code <= StatusBadGateway {
return true
}
if code >= 3000 && code <= 4999 {
return true
}

return false
}

func (ce CloseError) bytes() ([]byte, error) {
p, err := ce.bytesErr()
if err != nil {
err = fmt.Errorf("failed to marshal close frame: %w", err)
ce = CloseError{
Code: StatusInternalError,
}
p, _ = ce.bytesErr()
}
return p, err
}

const maxCloseReason = maxControlPayload - 2

func (ce CloseError) bytesErr() ([]byte, error) {
if len(ce.Reason) > maxCloseReason {
return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason))
}

if !validWireCloseCode(ce.Code) {
return nil, fmt.Errorf("status code %v cannot be set", ce.Code)
}

buf := make([]byte, 2+len(ce.Reason))
binary.BigEndian.PutUint16(buf, uint16(ce.Code))
copy(buf[2:], ce.Reason)
return buf, nil
}

func (c *Conn) setCloseErr(err error) {
c.closeMu.Lock()
c.setCloseErrLocked(err)
c.closeMu.Unlock()
}

func (c *Conn) setCloseErrLocked(err error) {
if c.closeErr == nil {
c.closeErr = fmt.Errorf("WebSocket closed: %w", err)
}
}

func (c *Conn) isClosed() bool {
select {
case <-c.closed:
return true
default:
return false
}
}
Loading