From b453d3ee377e513da7fdf5f9241e0be088489717 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Mon, 18 May 2020 02:47:53 -0400 Subject: [PATCH 1/5] Move all Wasm related code into ws_js.go This way we don't pollute the directory tree. --- accept_js.go | 20 ---- close.go | 205 +++++++++++++++++++++++++++++++++++ close_notjs.go | 211 ------------------------------------ compress.go | 180 +++++++++++++++++++++++++++++++ compress_notjs.go | 181 ------------------------------- conn.go | 264 +++++++++++++++++++++++++++++++++++++++++++++ conn_notjs.go | 265 ---------------------------------------------- ws_js.go | 134 +++++++++++++++++++++++ 8 files changed, 783 insertions(+), 677 deletions(-) delete mode 100644 accept_js.go delete mode 100644 close_notjs.go delete mode 100644 compress_notjs.go delete mode 100644 conn_notjs.go diff --git a/accept_js.go b/accept_js.go deleted file mode 100644 index daad4b79..00000000 --- a/accept_js.go +++ /dev/null @@ -1,20 +0,0 @@ -package websocket - -import ( - "errors" - "net/http" -) - -// AcceptOptions represents Accept's options. -type AcceptOptions struct { - Subprotocols []string - InsecureSkipVerify bool - OriginPatterns []string - CompressionMode CompressionMode - CompressionThreshold int -} - -// Accept is stubbed out for Wasm. -func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { - return nil, errors.New("unimplemented") -} diff --git a/close.go b/close.go index 7cbc19e9..d76dc2f4 100644 --- a/close.go +++ b/close.go @@ -1,8 +1,16 @@ +// +build !js + package websocket import ( + "context" + "encoding/binary" "errors" "fmt" + "log" + "time" + + "nhooyr.io/websocket/internal/errd" ) // StatusCode represents a WebSocket status code. @@ -74,3 +82,200 @@ func CloseStatus(err error) StatusCode { } return -1 } + +// Close performs the WebSocket close handshake with the given status code and reason. +// +// It will write a WebSocket close frame with a timeout of 5s and then wait 5s for +// the peer to send a close frame. +// All data messages received from the peer during the close handshake will be discarded. +// +// The connection can only be closed once. Additional calls to Close +// are no-ops. +// +// The maximum length of reason must be 125 bytes. Avoid +// sending a dynamic reason. +// +// Close will unblock all goroutines interacting with the connection once +// complete. +func (c *Conn) Close(code StatusCode, reason string) error { + return c.closeHandshake(code, reason) +} + +func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { + defer errd.Wrap(&err, "failed to close WebSocket") + + writeErr := c.writeClose(code, reason) + closeHandshakeErr := c.waitCloseHandshake() + + if writeErr != nil { + return writeErr + } + + if CloseStatus(closeHandshakeErr) == -1 { + return closeHandshakeErr + } + + return nil +} + +var errAlreadyWroteClose = errors.New("already wrote close") + +func (c *Conn) writeClose(code StatusCode, reason string) error { + c.closeMu.Lock() + wroteClose := c.wroteClose + c.wroteClose = true + c.closeMu.Unlock() + if wroteClose { + return errAlreadyWroteClose + } + + ce := CloseError{ + Code: code, + Reason: reason, + } + + var p []byte + var marshalErr error + if ce.Code != StatusNoStatusRcvd { + p, marshalErr = ce.bytes() + if marshalErr != nil { + log.Printf("websocket: %v", marshalErr) + } + } + + writeErr := c.writeControl(context.Background(), opClose, p) + if CloseStatus(writeErr) != -1 { + // Not a real error if it's due to a close frame being received. + writeErr = nil + } + + // We do this after in case there was an error writing the close frame. + c.setCloseErr(fmt.Errorf("sent close frame: %w", ce)) + + if marshalErr != nil { + return marshalErr + } + return writeErr +} + +func (c *Conn) waitCloseHandshake() error { + defer c.close(nil) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + err := c.readMu.lock(ctx) + if err != nil { + return err + } + defer c.readMu.unlock() + + if c.readCloseFrameErr != nil { + return c.readCloseFrameErr + } + + for { + h, err := c.readLoop(ctx) + if err != nil { + return err + } + + for i := int64(0); i < h.payloadLength; i++ { + _, err := c.br.ReadByte() + if err != nil { + return err + } + } + } +} + +func parseClosePayload(p []byte) (CloseError, error) { + if len(p) == 0 { + return CloseError{ + Code: StatusNoStatusRcvd, + }, nil + } + + if len(p) < 2 { + return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) + } + + ce := CloseError{ + Code: StatusCode(binary.BigEndian.Uint16(p)), + Reason: string(p[2:]), + } + + if !validWireCloseCode(ce.Code) { + return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code) + } + + return ce, nil +} + +// See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number +// and https://tools.ietf.org/html/rfc6455#section-7.4.1 +func validWireCloseCode(code StatusCode) bool { + switch code { + case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake: + return false + } + + if code >= StatusNormalClosure && code <= StatusBadGateway { + return true + } + if code >= 3000 && code <= 4999 { + return true + } + + return false +} + +func (ce CloseError) bytes() ([]byte, error) { + p, err := ce.bytesErr() + if err != nil { + err = fmt.Errorf("failed to marshal close frame: %w", err) + ce = CloseError{ + Code: StatusInternalError, + } + p, _ = ce.bytesErr() + } + return p, err +} + +const maxCloseReason = maxControlPayload - 2 + +func (ce CloseError) bytesErr() ([]byte, error) { + if len(ce.Reason) > maxCloseReason { + return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason)) + } + + if !validWireCloseCode(ce.Code) { + return nil, fmt.Errorf("status code %v cannot be set", ce.Code) + } + + buf := make([]byte, 2+len(ce.Reason)) + binary.BigEndian.PutUint16(buf, uint16(ce.Code)) + copy(buf[2:], ce.Reason) + return buf, nil +} + +func (c *Conn) setCloseErr(err error) { + c.closeMu.Lock() + c.setCloseErrLocked(err) + c.closeMu.Unlock() +} + +func (c *Conn) setCloseErrLocked(err error) { + if c.closeErr == nil { + c.closeErr = fmt.Errorf("WebSocket closed: %w", err) + } +} + +func (c *Conn) isClosed() bool { + select { + case <-c.closed: + return true + default: + return false + } +} diff --git a/close_notjs.go b/close_notjs.go deleted file mode 100644 index 4251311d..00000000 --- a/close_notjs.go +++ /dev/null @@ -1,211 +0,0 @@ -// +build !js - -package websocket - -import ( - "context" - "encoding/binary" - "errors" - "fmt" - "log" - "time" - - "nhooyr.io/websocket/internal/errd" -) - -// Close performs the WebSocket close handshake with the given status code and reason. -// -// It will write a WebSocket close frame with a timeout of 5s and then wait 5s for -// the peer to send a close frame. -// All data messages received from the peer during the close handshake will be discarded. -// -// The connection can only be closed once. Additional calls to Close -// are no-ops. -// -// The maximum length of reason must be 125 bytes. Avoid -// sending a dynamic reason. -// -// Close will unblock all goroutines interacting with the connection once -// complete. -func (c *Conn) Close(code StatusCode, reason string) error { - return c.closeHandshake(code, reason) -} - -func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { - defer errd.Wrap(&err, "failed to close WebSocket") - - writeErr := c.writeClose(code, reason) - closeHandshakeErr := c.waitCloseHandshake() - - if writeErr != nil { - return writeErr - } - - if CloseStatus(closeHandshakeErr) == -1 { - return closeHandshakeErr - } - - return nil -} - -var errAlreadyWroteClose = errors.New("already wrote close") - -func (c *Conn) writeClose(code StatusCode, reason string) error { - c.closeMu.Lock() - wroteClose := c.wroteClose - c.wroteClose = true - c.closeMu.Unlock() - if wroteClose { - return errAlreadyWroteClose - } - - ce := CloseError{ - Code: code, - Reason: reason, - } - - var p []byte - var marshalErr error - if ce.Code != StatusNoStatusRcvd { - p, marshalErr = ce.bytes() - if marshalErr != nil { - log.Printf("websocket: %v", marshalErr) - } - } - - writeErr := c.writeControl(context.Background(), opClose, p) - if CloseStatus(writeErr) != -1 { - // Not a real error if it's due to a close frame being received. - writeErr = nil - } - - // We do this after in case there was an error writing the close frame. - c.setCloseErr(fmt.Errorf("sent close frame: %w", ce)) - - if marshalErr != nil { - return marshalErr - } - return writeErr -} - -func (c *Conn) waitCloseHandshake() error { - defer c.close(nil) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - err := c.readMu.lock(ctx) - if err != nil { - return err - } - defer c.readMu.unlock() - - if c.readCloseFrameErr != nil { - return c.readCloseFrameErr - } - - for { - h, err := c.readLoop(ctx) - if err != nil { - return err - } - - for i := int64(0); i < h.payloadLength; i++ { - _, err := c.br.ReadByte() - if err != nil { - return err - } - } - } -} - -func parseClosePayload(p []byte) (CloseError, error) { - if len(p) == 0 { - return CloseError{ - Code: StatusNoStatusRcvd, - }, nil - } - - if len(p) < 2 { - return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) - } - - ce := CloseError{ - Code: StatusCode(binary.BigEndian.Uint16(p)), - Reason: string(p[2:]), - } - - if !validWireCloseCode(ce.Code) { - return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code) - } - - return ce, nil -} - -// See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number -// and https://tools.ietf.org/html/rfc6455#section-7.4.1 -func validWireCloseCode(code StatusCode) bool { - switch code { - case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake: - return false - } - - if code >= StatusNormalClosure && code <= StatusBadGateway { - return true - } - if code >= 3000 && code <= 4999 { - return true - } - - return false -} - -func (ce CloseError) bytes() ([]byte, error) { - p, err := ce.bytesErr() - if err != nil { - err = fmt.Errorf("failed to marshal close frame: %w", err) - ce = CloseError{ - Code: StatusInternalError, - } - p, _ = ce.bytesErr() - } - return p, err -} - -const maxCloseReason = maxControlPayload - 2 - -func (ce CloseError) bytesErr() ([]byte, error) { - if len(ce.Reason) > maxCloseReason { - return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason)) - } - - if !validWireCloseCode(ce.Code) { - return nil, fmt.Errorf("status code %v cannot be set", ce.Code) - } - - buf := make([]byte, 2+len(ce.Reason)) - binary.BigEndian.PutUint16(buf, uint16(ce.Code)) - copy(buf[2:], ce.Reason) - return buf, nil -} - -func (c *Conn) setCloseErr(err error) { - c.closeMu.Lock() - c.setCloseErrLocked(err) - c.closeMu.Unlock() -} - -func (c *Conn) setCloseErrLocked(err error) { - if c.closeErr == nil { - c.closeErr = fmt.Errorf("WebSocket closed: %w", err) - } -} - -func (c *Conn) isClosed() bool { - select { - case <-c.closed: - return true - default: - return false - } -} diff --git a/compress.go b/compress.go index 80b46d1c..63d961b4 100644 --- a/compress.go +++ b/compress.go @@ -1,5 +1,15 @@ +// +build !js + package websocket +import ( + "io" + "net/http" + "sync" + + "github.com/klauspost/compress/flate" +) + // CompressionMode represents the modes available to the deflate extension. // See https://tools.ietf.org/html/rfc7692 // @@ -37,3 +47,173 @@ const ( // important than bandwidth. CompressionDisabled ) + +func (m CompressionMode) opts() *compressionOptions { + return &compressionOptions{ + clientNoContextTakeover: m == CompressionNoContextTakeover, + serverNoContextTakeover: m == CompressionNoContextTakeover, + } +} + +type compressionOptions struct { + clientNoContextTakeover bool + serverNoContextTakeover bool +} + +func (copts *compressionOptions) setHeader(h http.Header) { + s := "permessage-deflate" + if copts.clientNoContextTakeover { + s += "; client_no_context_takeover" + } + if copts.serverNoContextTakeover { + s += "; server_no_context_takeover" + } + h.Set("Sec-WebSocket-Extensions", s) +} + +// These bytes are required to get flate.Reader to return. +// They are removed when sending to avoid the overhead as +// WebSocket framing tell's when the message has ended but then +// we need to add them back otherwise flate.Reader keeps +// trying to return more bytes. +const deflateMessageTail = "\x00\x00\xff\xff" + +type trimLastFourBytesWriter struct { + w io.Writer + tail []byte +} + +func (tw *trimLastFourBytesWriter) reset() { + if tw != nil && tw.tail != nil { + tw.tail = tw.tail[:0] + } +} + +func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) { + if tw.tail == nil { + tw.tail = make([]byte, 0, 4) + } + + extra := len(tw.tail) + len(p) - 4 + + if extra <= 0 { + tw.tail = append(tw.tail, p...) + return len(p), nil + } + + // Now we need to write as many extra bytes as we can from the previous tail. + if extra > len(tw.tail) { + extra = len(tw.tail) + } + if extra > 0 { + _, err := tw.w.Write(tw.tail[:extra]) + if err != nil { + return 0, err + } + + // Shift remaining bytes in tail over. + n := copy(tw.tail, tw.tail[extra:]) + tw.tail = tw.tail[:n] + } + + // If p is less than or equal to 4 bytes, + // all of it is is part of the tail. + if len(p) <= 4 { + tw.tail = append(tw.tail, p...) + return len(p), nil + } + + // Otherwise, only the last 4 bytes are. + tw.tail = append(tw.tail, p[len(p)-4:]...) + + p = p[:len(p)-4] + n, err := tw.w.Write(p) + return n + 4, err +} + +var flateReaderPool sync.Pool + +func getFlateReader(r io.Reader, dict []byte) io.Reader { + fr, ok := flateReaderPool.Get().(io.Reader) + if !ok { + return flate.NewReaderDict(r, dict) + } + fr.(flate.Resetter).Reset(r, dict) + return fr +} + +func putFlateReader(fr io.Reader) { + flateReaderPool.Put(fr) +} + +type slidingWindow struct { + buf []byte +} + +var swPoolMu sync.RWMutex +var swPool = map[int]*sync.Pool{} + +func slidingWindowPool(n int) *sync.Pool { + swPoolMu.RLock() + p, ok := swPool[n] + swPoolMu.RUnlock() + if ok { + return p + } + + p = &sync.Pool{} + + swPoolMu.Lock() + swPool[n] = p + swPoolMu.Unlock() + + return p +} + +func (sw *slidingWindow) init(n int) { + if sw.buf != nil { + return + } + + if n == 0 { + n = 32768 + } + + p := slidingWindowPool(n) + buf, ok := p.Get().([]byte) + if ok { + sw.buf = buf[:0] + } else { + sw.buf = make([]byte, 0, n) + } +} + +func (sw *slidingWindow) close() { + if sw.buf == nil { + return + } + + swPoolMu.Lock() + swPool[cap(sw.buf)].Put(sw.buf) + swPoolMu.Unlock() + sw.buf = nil +} + +func (sw *slidingWindow) write(p []byte) { + if len(p) >= cap(sw.buf) { + sw.buf = sw.buf[:cap(sw.buf)] + p = p[len(p)-cap(sw.buf):] + copy(sw.buf, p) + return + } + + left := cap(sw.buf) - len(sw.buf) + if left < len(p) { + // We need to shift spaceNeeded bytes from the end to make room for p at the end. + spaceNeeded := len(p) - left + copy(sw.buf, sw.buf[spaceNeeded:]) + sw.buf = sw.buf[:len(sw.buf)-spaceNeeded] + } + + sw.buf = append(sw.buf, p...) +} diff --git a/compress_notjs.go b/compress_notjs.go deleted file mode 100644 index 809a272c..00000000 --- a/compress_notjs.go +++ /dev/null @@ -1,181 +0,0 @@ -// +build !js - -package websocket - -import ( - "io" - "net/http" - "sync" - - "github.com/klauspost/compress/flate" -) - -func (m CompressionMode) opts() *compressionOptions { - return &compressionOptions{ - clientNoContextTakeover: m == CompressionNoContextTakeover, - serverNoContextTakeover: m == CompressionNoContextTakeover, - } -} - -type compressionOptions struct { - clientNoContextTakeover bool - serverNoContextTakeover bool -} - -func (copts *compressionOptions) setHeader(h http.Header) { - s := "permessage-deflate" - if copts.clientNoContextTakeover { - s += "; client_no_context_takeover" - } - if copts.serverNoContextTakeover { - s += "; server_no_context_takeover" - } - h.Set("Sec-WebSocket-Extensions", s) -} - -// These bytes are required to get flate.Reader to return. -// They are removed when sending to avoid the overhead as -// WebSocket framing tell's when the message has ended but then -// we need to add them back otherwise flate.Reader keeps -// trying to return more bytes. -const deflateMessageTail = "\x00\x00\xff\xff" - -type trimLastFourBytesWriter struct { - w io.Writer - tail []byte -} - -func (tw *trimLastFourBytesWriter) reset() { - if tw != nil && tw.tail != nil { - tw.tail = tw.tail[:0] - } -} - -func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) { - if tw.tail == nil { - tw.tail = make([]byte, 0, 4) - } - - extra := len(tw.tail) + len(p) - 4 - - if extra <= 0 { - tw.tail = append(tw.tail, p...) - return len(p), nil - } - - // Now we need to write as many extra bytes as we can from the previous tail. - if extra > len(tw.tail) { - extra = len(tw.tail) - } - if extra > 0 { - _, err := tw.w.Write(tw.tail[:extra]) - if err != nil { - return 0, err - } - - // Shift remaining bytes in tail over. - n := copy(tw.tail, tw.tail[extra:]) - tw.tail = tw.tail[:n] - } - - // If p is less than or equal to 4 bytes, - // all of it is is part of the tail. - if len(p) <= 4 { - tw.tail = append(tw.tail, p...) - return len(p), nil - } - - // Otherwise, only the last 4 bytes are. - tw.tail = append(tw.tail, p[len(p)-4:]...) - - p = p[:len(p)-4] - n, err := tw.w.Write(p) - return n + 4, err -} - -var flateReaderPool sync.Pool - -func getFlateReader(r io.Reader, dict []byte) io.Reader { - fr, ok := flateReaderPool.Get().(io.Reader) - if !ok { - return flate.NewReaderDict(r, dict) - } - fr.(flate.Resetter).Reset(r, dict) - return fr -} - -func putFlateReader(fr io.Reader) { - flateReaderPool.Put(fr) -} - -type slidingWindow struct { - buf []byte -} - -var swPoolMu sync.RWMutex -var swPool = map[int]*sync.Pool{} - -func slidingWindowPool(n int) *sync.Pool { - swPoolMu.RLock() - p, ok := swPool[n] - swPoolMu.RUnlock() - if ok { - return p - } - - p = &sync.Pool{} - - swPoolMu.Lock() - swPool[n] = p - swPoolMu.Unlock() - - return p -} - -func (sw *slidingWindow) init(n int) { - if sw.buf != nil { - return - } - - if n == 0 { - n = 32768 - } - - p := slidingWindowPool(n) - buf, ok := p.Get().([]byte) - if ok { - sw.buf = buf[:0] - } else { - sw.buf = make([]byte, 0, n) - } -} - -func (sw *slidingWindow) close() { - if sw.buf == nil { - return - } - - swPoolMu.Lock() - swPool[cap(sw.buf)].Put(sw.buf) - swPoolMu.Unlock() - sw.buf = nil -} - -func (sw *slidingWindow) write(p []byte) { - if len(p) >= cap(sw.buf) { - sw.buf = sw.buf[:cap(sw.buf)] - p = p[len(p)-cap(sw.buf):] - copy(sw.buf, p) - return - } - - left := cap(sw.buf) - len(sw.buf) - if left < len(p) { - // We need to shift spaceNeeded bytes from the end to make room for p at the end. - spaceNeeded := len(p) - left - copy(sw.buf, sw.buf[spaceNeeded:]) - sw.buf = sw.buf[:len(sw.buf)-spaceNeeded] - } - - sw.buf = append(sw.buf, p...) -} diff --git a/conn.go b/conn.go index a41808be..e208d116 100644 --- a/conn.go +++ b/conn.go @@ -1,5 +1,19 @@ +// +build !js + package websocket +import ( + "bufio" + "context" + "errors" + "fmt" + "io" + "runtime" + "strconv" + "sync" + "sync/atomic" +) + // MessageType represents the type of a WebSocket message. // See https://tools.ietf.org/html/rfc6455#section-5.6 type MessageType int @@ -11,3 +25,253 @@ const ( // MessageBinary is for binary messages like protobufs. MessageBinary ) + +// Conn represents a WebSocket connection. +// All methods may be called concurrently except for Reader and Read. +// +// You must always read from the connection. Otherwise control +// frames will not be handled. See Reader and CloseRead. +// +// Be sure to call Close on the connection when you +// are finished with it to release associated resources. +// +// On any error from any method, the connection is closed +// with an appropriate reason. +type Conn struct { + subprotocol string + rwc io.ReadWriteCloser + client bool + copts *compressionOptions + flateThreshold int + br *bufio.Reader + bw *bufio.Writer + + readTimeout chan context.Context + writeTimeout chan context.Context + + // Read state. + readMu *mu + readHeaderBuf [8]byte + readControlBuf [maxControlPayload]byte + msgReader *msgReader + readCloseFrameErr error + + // Write state. + msgWriterState *msgWriterState + writeFrameMu *mu + writeBuf []byte + writeHeaderBuf [8]byte + writeHeader header + + closed chan struct{} + closeMu sync.Mutex + closeErr error + wroteClose bool + + pingCounter int32 + activePingsMu sync.Mutex + activePings map[string]chan<- struct{} +} + +type connConfig struct { + subprotocol string + rwc io.ReadWriteCloser + client bool + copts *compressionOptions + flateThreshold int + + br *bufio.Reader + bw *bufio.Writer +} + +func newConn(cfg connConfig) *Conn { + c := &Conn{ + subprotocol: cfg.subprotocol, + rwc: cfg.rwc, + client: cfg.client, + copts: cfg.copts, + flateThreshold: cfg.flateThreshold, + + br: cfg.br, + bw: cfg.bw, + + readTimeout: make(chan context.Context), + writeTimeout: make(chan context.Context), + + closed: make(chan struct{}), + activePings: make(map[string]chan<- struct{}), + } + + c.readMu = newMu(c) + c.writeFrameMu = newMu(c) + + c.msgReader = newMsgReader(c) + + c.msgWriterState = newMsgWriterState(c) + if c.client { + c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc) + } + + if c.flate() && c.flateThreshold == 0 { + c.flateThreshold = 128 + if !c.msgWriterState.flateContextTakeover() { + c.flateThreshold = 512 + } + } + + runtime.SetFinalizer(c, func(c *Conn) { + c.close(errors.New("connection garbage collected")) + }) + + go c.timeoutLoop() + + return c +} + +// Subprotocol returns the negotiated subprotocol. +// An empty string means the default protocol. +func (c *Conn) Subprotocol() string { + return c.subprotocol +} + +func (c *Conn) close(err error) { + c.closeMu.Lock() + defer c.closeMu.Unlock() + + if c.isClosed() { + return + } + c.setCloseErrLocked(err) + close(c.closed) + runtime.SetFinalizer(c, nil) + + // Have to close after c.closed is closed to ensure any goroutine that wakes up + // from the connection being closed also sees that c.closed is closed and returns + // closeErr. + c.rwc.Close() + + go func() { + c.msgWriterState.close() + + c.msgReader.close() + }() +} + +func (c *Conn) timeoutLoop() { + readCtx := context.Background() + writeCtx := context.Background() + + for { + select { + case <-c.closed: + return + + case writeCtx = <-c.writeTimeout: + case readCtx = <-c.readTimeout: + + case <-readCtx.Done(): + c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err())) + go c.writeError(StatusPolicyViolation, errors.New("timed out")) + case <-writeCtx.Done(): + c.close(fmt.Errorf("write timed out: %w", writeCtx.Err())) + return + } + } +} + +func (c *Conn) flate() bool { + return c.copts != nil +} + +// Ping sends a ping to the peer and waits for a pong. +// Use this to measure latency or ensure the peer is responsive. +// Ping must be called concurrently with Reader as it does +// not read from the connection but instead waits for a Reader call +// to read the pong. +// +// TCP Keepalives should suffice for most use cases. +func (c *Conn) Ping(ctx context.Context) error { + p := atomic.AddInt32(&c.pingCounter, 1) + + err := c.ping(ctx, strconv.Itoa(int(p))) + if err != nil { + return fmt.Errorf("failed to ping: %w", err) + } + return nil +} + +func (c *Conn) ping(ctx context.Context, p string) error { + pong := make(chan struct{}) + + c.activePingsMu.Lock() + c.activePings[p] = pong + c.activePingsMu.Unlock() + + defer func() { + c.activePingsMu.Lock() + delete(c.activePings, p) + c.activePingsMu.Unlock() + }() + + err := c.writeControl(ctx, opPing, []byte(p)) + if err != nil { + return err + } + + select { + case <-c.closed: + return c.closeErr + case <-ctx.Done(): + err := fmt.Errorf("failed to wait for pong: %w", ctx.Err()) + c.close(err) + return err + case <-pong: + return nil + } +} + +type mu struct { + c *Conn + ch chan struct{} +} + +func newMu(c *Conn) *mu { + return &mu{ + c: c, + ch: make(chan struct{}, 1), + } +} + +func (m *mu) forceLock() { + m.ch <- struct{}{} +} + +func (m *mu) lock(ctx context.Context) error { + select { + case <-m.c.closed: + return m.c.closeErr + case <-ctx.Done(): + err := fmt.Errorf("failed to acquire lock: %w", ctx.Err()) + m.c.close(err) + return err + case m.ch <- struct{}{}: + // To make sure the connection is certainly alive. + // As it's possible the send on m.ch was selected + // over the receive on closed. + select { + case <-m.c.closed: + // Make sure to release. + m.unlock() + return m.c.closeErr + default: + } + return nil + } +} + +func (m *mu) unlock() { + select { + case <-m.ch: + default: + } +} diff --git a/conn_notjs.go b/conn_notjs.go deleted file mode 100644 index bb2eb22f..00000000 --- a/conn_notjs.go +++ /dev/null @@ -1,265 +0,0 @@ -// +build !js - -package websocket - -import ( - "bufio" - "context" - "errors" - "fmt" - "io" - "runtime" - "strconv" - "sync" - "sync/atomic" -) - -// Conn represents a WebSocket connection. -// All methods may be called concurrently except for Reader and Read. -// -// You must always read from the connection. Otherwise control -// frames will not be handled. See Reader and CloseRead. -// -// Be sure to call Close on the connection when you -// are finished with it to release associated resources. -// -// On any error from any method, the connection is closed -// with an appropriate reason. -type Conn struct { - subprotocol string - rwc io.ReadWriteCloser - client bool - copts *compressionOptions - flateThreshold int - br *bufio.Reader - bw *bufio.Writer - - readTimeout chan context.Context - writeTimeout chan context.Context - - // Read state. - readMu *mu - readHeaderBuf [8]byte - readControlBuf [maxControlPayload]byte - msgReader *msgReader - readCloseFrameErr error - - // Write state. - msgWriterState *msgWriterState - writeFrameMu *mu - writeBuf []byte - writeHeaderBuf [8]byte - writeHeader header - - closed chan struct{} - closeMu sync.Mutex - closeErr error - wroteClose bool - - pingCounter int32 - activePingsMu sync.Mutex - activePings map[string]chan<- struct{} -} - -type connConfig struct { - subprotocol string - rwc io.ReadWriteCloser - client bool - copts *compressionOptions - flateThreshold int - - br *bufio.Reader - bw *bufio.Writer -} - -func newConn(cfg connConfig) *Conn { - c := &Conn{ - subprotocol: cfg.subprotocol, - rwc: cfg.rwc, - client: cfg.client, - copts: cfg.copts, - flateThreshold: cfg.flateThreshold, - - br: cfg.br, - bw: cfg.bw, - - readTimeout: make(chan context.Context), - writeTimeout: make(chan context.Context), - - closed: make(chan struct{}), - activePings: make(map[string]chan<- struct{}), - } - - c.readMu = newMu(c) - c.writeFrameMu = newMu(c) - - c.msgReader = newMsgReader(c) - - c.msgWriterState = newMsgWriterState(c) - if c.client { - c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc) - } - - if c.flate() && c.flateThreshold == 0 { - c.flateThreshold = 128 - if !c.msgWriterState.flateContextTakeover() { - c.flateThreshold = 512 - } - } - - runtime.SetFinalizer(c, func(c *Conn) { - c.close(errors.New("connection garbage collected")) - }) - - go c.timeoutLoop() - - return c -} - -// Subprotocol returns the negotiated subprotocol. -// An empty string means the default protocol. -func (c *Conn) Subprotocol() string { - return c.subprotocol -} - -func (c *Conn) close(err error) { - c.closeMu.Lock() - defer c.closeMu.Unlock() - - if c.isClosed() { - return - } - c.setCloseErrLocked(err) - close(c.closed) - runtime.SetFinalizer(c, nil) - - // Have to close after c.closed is closed to ensure any goroutine that wakes up - // from the connection being closed also sees that c.closed is closed and returns - // closeErr. - c.rwc.Close() - - go func() { - c.msgWriterState.close() - - c.msgReader.close() - }() -} - -func (c *Conn) timeoutLoop() { - readCtx := context.Background() - writeCtx := context.Background() - - for { - select { - case <-c.closed: - return - - case writeCtx = <-c.writeTimeout: - case readCtx = <-c.readTimeout: - - case <-readCtx.Done(): - c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err())) - go c.writeError(StatusPolicyViolation, errors.New("timed out")) - case <-writeCtx.Done(): - c.close(fmt.Errorf("write timed out: %w", writeCtx.Err())) - return - } - } -} - -func (c *Conn) flate() bool { - return c.copts != nil -} - -// Ping sends a ping to the peer and waits for a pong. -// Use this to measure latency or ensure the peer is responsive. -// Ping must be called concurrently with Reader as it does -// not read from the connection but instead waits for a Reader call -// to read the pong. -// -// TCP Keepalives should suffice for most use cases. -func (c *Conn) Ping(ctx context.Context) error { - p := atomic.AddInt32(&c.pingCounter, 1) - - err := c.ping(ctx, strconv.Itoa(int(p))) - if err != nil { - return fmt.Errorf("failed to ping: %w", err) - } - return nil -} - -func (c *Conn) ping(ctx context.Context, p string) error { - pong := make(chan struct{}) - - c.activePingsMu.Lock() - c.activePings[p] = pong - c.activePingsMu.Unlock() - - defer func() { - c.activePingsMu.Lock() - delete(c.activePings, p) - c.activePingsMu.Unlock() - }() - - err := c.writeControl(ctx, opPing, []byte(p)) - if err != nil { - return err - } - - select { - case <-c.closed: - return c.closeErr - case <-ctx.Done(): - err := fmt.Errorf("failed to wait for pong: %w", ctx.Err()) - c.close(err) - return err - case <-pong: - return nil - } -} - -type mu struct { - c *Conn - ch chan struct{} -} - -func newMu(c *Conn) *mu { - return &mu{ - c: c, - ch: make(chan struct{}, 1), - } -} - -func (m *mu) forceLock() { - m.ch <- struct{}{} -} - -func (m *mu) lock(ctx context.Context) error { - select { - case <-m.c.closed: - return m.c.closeErr - case <-ctx.Done(): - err := fmt.Errorf("failed to acquire lock: %w", ctx.Err()) - m.c.close(err) - return err - case m.ch <- struct{}{}: - // To make sure the connection is certainly alive. - // As it's possible the send on m.ch was selected - // over the receive on closed. - select { - case <-m.c.closed: - // Make sure to release. - m.unlock() - return m.c.closeErr - default: - } - return nil - } -} - -func (m *mu) unlock() { - select { - case <-m.ch: - default: - } -} diff --git a/ws_js.go b/ws_js.go index b87e32cd..31e3c2f6 100644 --- a/ws_js.go +++ b/ws_js.go @@ -377,3 +377,137 @@ func (c *Conn) isClosed() bool { return false } } + +// AcceptOptions represents Accept's options. +type AcceptOptions struct { + Subprotocols []string + InsecureSkipVerify bool + OriginPatterns []string + CompressionMode CompressionMode + CompressionThreshold int +} + +// Accept is stubbed out for Wasm. +func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { + return nil, errors.New("unimplemented") +} + +// StatusCode represents a WebSocket status code. +// https://tools.ietf.org/html/rfc6455#section-7.4 +type StatusCode int + +// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number +// +// These are only the status codes defined by the protocol. +// +// You can define custom codes in the 3000-4999 range. +// The 3000-3999 range is reserved for use by libraries, frameworks and applications. +// The 4000-4999 range is reserved for private use. +const ( + StatusNormalClosure StatusCode = 1000 + StatusGoingAway StatusCode = 1001 + StatusProtocolError StatusCode = 1002 + StatusUnsupportedData StatusCode = 1003 + + // 1004 is reserved and so unexported. + statusReserved StatusCode = 1004 + + // StatusNoStatusRcvd cannot be sent in a close message. + // It is reserved for when a close message is received without + // a status code. + StatusNoStatusRcvd StatusCode = 1005 + + // StatusAbnormalClosure is exported for use only with Wasm. + // In non Wasm Go, the returned error will indicate whether the + // connection was closed abnormally. + StatusAbnormalClosure StatusCode = 1006 + + StatusInvalidFramePayloadData StatusCode = 1007 + StatusPolicyViolation StatusCode = 1008 + StatusMessageTooBig StatusCode = 1009 + StatusMandatoryExtension StatusCode = 1010 + StatusInternalError StatusCode = 1011 + StatusServiceRestart StatusCode = 1012 + StatusTryAgainLater StatusCode = 1013 + StatusBadGateway StatusCode = 1014 + + // StatusTLSHandshake is only exported for use with Wasm. + // In non Wasm Go, the returned error will indicate whether there was + // a TLS handshake failure. + StatusTLSHandshake StatusCode = 1015 +) + +// CloseError is returned when the connection is closed with a status and reason. +// +// Use Go 1.13's errors.As to check for this error. +// Also see the CloseStatus helper. +type CloseError struct { + Code StatusCode + Reason string +} + +func (ce CloseError) Error() string { + return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason) +} + +// CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab +// the status code from a CloseError. +// +// -1 will be returned if the passed error is nil or not a CloseError. +func CloseStatus(err error) StatusCode { + var ce CloseError + if errors.As(err, &ce) { + return ce.Code + } + return -1 +} + +// CompressionMode represents the modes available to the deflate extension. +// See https://tools.ietf.org/html/rfc7692 +// +// A compatibility layer is implemented for the older deflate-frame extension used +// by safari. See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06 +// It will work the same in every way except that we cannot signal to the peer we +// want to use no context takeover on our side, we can only signal that they should. +// It is however currently disabled due to Safari bugs. See https://github.com/nhooyr/websocket/issues/218 +type CompressionMode int + +const ( + // CompressionNoContextTakeover grabs a new flate.Reader and flate.Writer as needed + // for every message. This applies to both server and client side. + // + // This means less efficient compression as the sliding window from previous messages + // will not be used but the memory overhead will be lower if the connections + // are long lived and seldom used. + // + // The message will only be compressed if greater than 512 bytes. + CompressionNoContextTakeover CompressionMode = iota + + // CompressionContextTakeover uses a flate.Reader and flate.Writer per connection. + // This enables reusing the sliding window from previous messages. + // As most WebSocket protocols are repetitive, this can be very efficient. + // It carries an overhead of 8 kB for every connection compared to CompressionNoContextTakeover. + // + // If the peer negotiates NoContextTakeover on the client or server side, it will be + // used instead as this is required by the RFC. + CompressionContextTakeover + + // CompressionDisabled disables the deflate extension. + // + // Use this if you are using a predominantly binary protocol with very + // little duplication in between messages or CPU and memory are more + // important than bandwidth. + CompressionDisabled +) + +// MessageType represents the type of a WebSocket message. +// See https://tools.ietf.org/html/rfc6455#section-5.6 +type MessageType int + +// MessageType constants. +const ( + // MessageText is for UTF-8 encoded text messages like JSON. + MessageText MessageType = iota + 1 + // MessageBinary is for binary messages like protobufs. + MessageBinary +) From 17cf0fe86c9c23e64714986b266a15fd9a26142d Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Mon, 18 May 2020 03:12:08 -0400 Subject: [PATCH 2/5] Disable compression by default Closes #220 and #230 --- README.md | 3 +-- accept.go | 2 +- accept_test.go | 4 +++- autobahn_test.go | 4 +++- compress.go | 60 +++++++++++++++++++++++++++++------------------- conn_test.go | 4 ++-- dial.go | 2 +- go.mod | 1 - go.sum | 2 -- write.go | 35 ++++++++++++++++++---------- 10 files changed, 71 insertions(+), 46 deletions(-) diff --git a/README.md b/README.md index df20c581..8420bdbd 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ go get nhooyr.io/websocket - Minimal and idiomatic API - First class [context.Context](https://blog.golang.org/context) support - Fully passes the WebSocket [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) -- [Single dependency](https://pkg.go.dev/nhooyr.io/websocket?tab=imports) +- [Zero dependencies](https://pkg.go.dev/nhooyr.io/websocket?tab=imports) - JSON and protobuf helpers in the [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) and [wspb](https://pkg.go.dev/nhooyr.io/websocket/wspb) subpackages - Zero alloc reads and writes - Concurrent writes @@ -112,7 +112,6 @@ Advantages of nhooyr.io/websocket: - Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/). - Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support - Gorilla only supports no context takeover mode - - We use [klauspost/compress](https://github.com/klauspost/compress) for much lower memory usage ([gorilla/websocket#203](https://github.com/gorilla/websocket/issues/203)) - [CloseRead](https://pkg.go.dev/nhooyr.io/websocket#Conn.CloseRead) helper ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492)) - Actively maintained ([gorilla/websocket#370](https://github.com/gorilla/websocket/issues/370)) diff --git a/accept.go b/accept.go index 66379b5d..f038dec9 100644 --- a/accept.go +++ b/accept.go @@ -51,7 +51,7 @@ type AcceptOptions struct { OriginPatterns []string // CompressionMode controls the compression mode. - // Defaults to CompressionNoContextTakeover. + // Defaults to CompressionDisabled. // // See docs on CompressionMode for details. CompressionMode CompressionMode diff --git a/accept_test.go b/accept_test.go index 9b18d8e1..f7bc6693 100644 --- a/accept_test.go +++ b/accept_test.go @@ -55,7 +55,9 @@ func TestAccept(t *testing.T) { r.Header.Set("Sec-WebSocket-Key", "meow123") r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; harharhar") - _, err := Accept(w, r, nil) + _, err := Accept(w, r, &AcceptOptions{ + CompressionMode: CompressionContextTakeover, + }) assert.Contains(t, err, `unsupported permessage-deflate parameter`) }) diff --git a/autobahn_test.go b/autobahn_test.go index e56a4912..d53159a0 100644 --- a/autobahn_test.go +++ b/autobahn_test.go @@ -61,7 +61,9 @@ func TestAutobahn(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) defer cancel() - c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/runCase?case=%v&agent=main", i), nil) + c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/runCase?case=%v&agent=main", i), &websocket.DialOptions{ + CompressionMode: websocket.CompressionContextTakeover, + }) assert.Success(t, err) err = wstest.EchoLoop(ctx, c) t.Logf("echoLoop: %v", err) diff --git a/compress.go b/compress.go index 63d961b4..f49d9e5d 100644 --- a/compress.go +++ b/compress.go @@ -3,49 +3,47 @@ package websocket import ( + "compress/flate" "io" "net/http" "sync" - - "github.com/klauspost/compress/flate" ) // CompressionMode represents the modes available to the deflate extension. // See https://tools.ietf.org/html/rfc7692 -// -// A compatibility layer is implemented for the older deflate-frame extension used -// by safari. See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06 -// It will work the same in every way except that we cannot signal to the peer we -// want to use no context takeover on our side, we can only signal that they should. -// It is however currently disabled due to Safari bugs. See https://github.com/nhooyr/websocket/issues/218 type CompressionMode int const ( - // CompressionNoContextTakeover grabs a new flate.Reader and flate.Writer as needed - // for every message. This applies to both server and client side. + // CompressionDisabled disables the deflate extension. // - // This means less efficient compression as the sliding window from previous messages - // will not be used but the memory overhead will be lower if the connections - // are long lived and seldom used. + // Use this if you are using a predominantly binary protocol with very + // little duplication in between messages or CPU and memory are more + // important than bandwidth. // - // The message will only be compressed if greater than 512 bytes. - CompressionNoContextTakeover CompressionMode = iota + // This is the default. + CompressionDisabled CompressionMode = iota - // CompressionContextTakeover uses a flate.Reader and flate.Writer per connection. - // This enables reusing the sliding window from previous messages. + // CompressionContextTakeover uses a 32 kB sliding window and flate.Writer per connection. + // It reusing the sliding window from previous messages. // As most WebSocket protocols are repetitive, this can be very efficient. - // It carries an overhead of 8 kB for every connection compared to CompressionNoContextTakeover. + // It carries an overhead of 32 kB + 1.2 MB for every connection compared to CompressionNoContextTakeover. + // + // Sometime in the future it will carry 65 kB overhead instead once https://github.com/golang/go/issues/36919 + // is fixed. // // If the peer negotiates NoContextTakeover on the client or server side, it will be // used instead as this is required by the RFC. CompressionContextTakeover - // CompressionDisabled disables the deflate extension. + // CompressionNoContextTakeover grabs a new flate.Reader and flate.Writer as needed + // for every message. This applies to both server and client side. // - // Use this if you are using a predominantly binary protocol with very - // little duplication in between messages or CPU and memory are more - // important than bandwidth. - CompressionDisabled + // This means less efficient compression as the sliding window from previous messages + // will not be used but the memory overhead will be lower if the connections + // are long lived and seldom used. + // + // The message will only be compressed if greater than 512 bytes. + CompressionNoContextTakeover ) func (m CompressionMode) opts() *compressionOptions { @@ -146,6 +144,22 @@ func putFlateReader(fr io.Reader) { flateReaderPool.Put(fr) } +var flateWriterPool sync.Pool + +func getFlateWriter(w io.Writer) *flate.Writer { + fw, ok := flateWriterPool.Get().(*flate.Writer) + if !ok { + fw, _ = flate.NewWriter(w, flate.BestSpeed) + return fw + } + fw.Reset(w) + return fw +} + +func putFlateWriter(w *flate.Writer) { + flateWriterPool.Put(w) +} + type slidingWindow struct { buf []byte } diff --git a/conn_test.go b/conn_test.go index c2c41292..4bab5adf 100644 --- a/conn_test.go +++ b/conn_test.go @@ -37,7 +37,7 @@ func TestConn(t *testing.T) { t.Parallel() compressionMode := func() websocket.CompressionMode { - return websocket.CompressionMode(xrand.Int(int(websocket.CompressionDisabled) + 1)) + return websocket.CompressionMode(xrand.Int(int(websocket.CompressionContextTakeover) + 1)) } for i := 0; i < 5; i++ { @@ -389,7 +389,7 @@ func BenchmarkConn(b *testing.B) { mode: websocket.CompressionDisabled, }, { - name: "compress", + name: "compressContextTakeover", mode: websocket.CompressionContextTakeover, }, { diff --git a/dial.go b/dial.go index 2b25e351..9ec90444 100644 --- a/dial.go +++ b/dial.go @@ -35,7 +35,7 @@ type DialOptions struct { Subprotocols []string // CompressionMode controls the compression mode. - // Defaults to CompressionNoContextTakeover. + // Defaults to CompressionDisabled. // // See docs on CompressionMode for details. CompressionMode CompressionMode diff --git a/go.mod b/go.mod index c5f1a20f..d4bca923 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,5 @@ require ( github.com/golang/protobuf v1.3.5 github.com/google/go-cmp v0.4.0 github.com/gorilla/websocket v1.4.1 - github.com/klauspost/compress v1.10.3 golang.org/x/time v0.0.0-20191024005414-555d28b269f0 ) diff --git a/go.sum b/go.sum index 155c3013..1344e958 100644 --- a/go.sum +++ b/go.sum @@ -29,8 +29,6 @@ github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvK github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/json-iterator/go v1.1.9 h1:9yzud/Ht36ygwatGx56VwCZtlI/2AD15T1X2sjSuGns= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= -github.com/klauspost/compress v1.10.3 h1:OP96hzwJVBIHYU52pVTI6CczrxPvrGfgqF9N5eTO0Q8= -github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= diff --git a/write.go b/write.go index 2210cf81..b1c57c1b 100644 --- a/write.go +++ b/write.go @@ -12,7 +12,7 @@ import ( "io" "time" - "github.com/klauspost/compress/flate" + "compress/flate" "nhooyr.io/websocket/internal/errd" ) @@ -76,8 +76,8 @@ type msgWriterState struct { opcode opcode flate bool - trimWriter *trimLastFourBytesWriter - dict slidingWindow + trimWriter *trimLastFourBytesWriter + flateWriter *flate.Writer } func newMsgWriterState(c *Conn) *msgWriterState { @@ -96,7 +96,9 @@ func (mw *msgWriterState) ensureFlate() { } } - mw.dict.init(8192) + if mw.flateWriter == nil { + mw.flateWriter = getFlateWriter(mw.trimWriter) + } mw.flate = true } @@ -153,6 +155,13 @@ func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error { return nil } +func (mw *msgWriterState) putFlateWriter() { + if mw.flateWriter != nil { + putFlateWriter(mw.flateWriter) + mw.flateWriter = nil + } +} + // Write writes the given bytes to the WebSocket connection. func (mw *msgWriterState) Write(p []byte) (_ int, err error) { err = mw.writeMu.lock(mw.ctx) @@ -177,12 +186,7 @@ func (mw *msgWriterState) Write(p []byte) (_ int, err error) { } if mw.flate { - err = flate.StatelessDeflate(mw.trimWriter, p, false, mw.dict.buf) - if err != nil { - return 0, err - } - mw.dict.write(p) - return len(p), nil + return mw.flateWriter.Write(p) } return mw.write(p) @@ -207,13 +211,20 @@ func (mw *msgWriterState) Close() (err error) { } defer mw.writeMu.unlock() + if mw.flate { + err = mw.flateWriter.Flush() + if err != nil { + return fmt.Errorf("failed to flush flate: %w", err) + } + } + _, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil) if err != nil { return fmt.Errorf("failed to write fin frame: %w", err) } if mw.flate && !mw.flateContextTakeover() { - mw.dict.close() + mw.putFlateWriter() } mw.mu.unlock() return nil @@ -226,7 +237,7 @@ func (mw *msgWriterState) close() { } mw.writeMu.forceLock() - mw.dict.close() + mw.putFlateWriter() } func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { From de8e29bdb753bc55c8f742c664adb44833afbc50 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Mon, 18 May 2020 04:25:52 -0400 Subject: [PATCH 3/5] Fix tests taking too long and switch to t.Cleanup --- autobahn_test.go | 7 ++++++- conn_test.go | 47 +++++++++++++---------------------------------- 2 files changed, 19 insertions(+), 35 deletions(-) diff --git a/autobahn_test.go b/autobahn_test.go index d53159a0..5bf0062c 100644 --- a/autobahn_test.go +++ b/autobahn_test.go @@ -28,7 +28,6 @@ var excludedAutobahnCases = []string{ // We skip the tests related to requestMaxWindowBits as that is unimplemented due // to limitations in compress/flate. See https://github.com/golang/go/issues/3155 - // Same with klauspost/compress which doesn't allow adjusting the sliding window size. "13.3.*", "13.4.*", "13.5.*", "13.6.*", } @@ -41,6 +40,12 @@ func TestAutobahn(t *testing.T) { t.SkipNow() } + if os.Getenv("AUTOBAHN_FAST") != "" { + excludedAutobahnCases = append(excludedAutobahnCases, + "9.*", "13.*", "12.*", + ) + } + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*15) defer cancel() diff --git a/conn_test.go b/conn_test.go index 4bab5adf..9c85459e 100644 --- a/conn_test.go +++ b/conn_test.go @@ -49,7 +49,6 @@ func TestConn(t *testing.T) { CompressionMode: compressionMode(), CompressionThreshold: xrand.Int(9999), }) - defer tt.cleanup() tt.goEchoLoop(c2) @@ -67,8 +66,9 @@ func TestConn(t *testing.T) { }) t.Run("badClose", func(t *testing.T) { - tt, c1, _ := newConnTest(t, nil, nil) - defer tt.cleanup() + tt, c1, c2 := newConnTest(t, nil, nil) + + c2.CloseRead(tt.ctx) err := c1.Close(-1, "") assert.Contains(t, err, "failed to marshal close frame: status code StatusCode(-1) cannot be set") @@ -76,7 +76,6 @@ func TestConn(t *testing.T) { t.Run("ping", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.cleanup() c1.CloseRead(tt.ctx) c2.CloseRead(tt.ctx) @@ -92,7 +91,6 @@ func TestConn(t *testing.T) { t.Run("badPing", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.cleanup() c2.CloseRead(tt.ctx) @@ -105,7 +103,6 @@ func TestConn(t *testing.T) { t.Run("concurrentWrite", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.cleanup() tt.goDiscardLoop(c2) @@ -138,7 +135,6 @@ func TestConn(t *testing.T) { t.Run("concurrentWriteError", func(t *testing.T) { tt, c1, _ := newConnTest(t, nil, nil) - defer tt.cleanup() _, err := c1.Writer(tt.ctx, websocket.MessageText) assert.Success(t, err) @@ -152,7 +148,6 @@ func TestConn(t *testing.T) { t.Run("netConn", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.cleanup() n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary) @@ -192,17 +187,14 @@ func TestConn(t *testing.T) { t.Run("netConn/BadMsg", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.cleanup() n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary) n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageText) + c2.CloseRead(tt.ctx) errs := xsync.Go(func() error { _, err := n2.Write([]byte("hello")) - if err != nil { - return err - } - return nil + return err }) _, err := ioutil.ReadAll(n1) @@ -218,7 +210,6 @@ func TestConn(t *testing.T) { t.Run("wsjson", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.cleanup() tt.goEchoLoop(c2) @@ -248,7 +239,6 @@ func TestConn(t *testing.T) { t.Run("wspb", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) - defer tt.cleanup() tt.goEchoLoop(c2) @@ -305,8 +295,6 @@ func assertCloseStatus(exp websocket.StatusCode, err error) error { type connTest struct { t testing.TB ctx context.Context - - doneFuncs []func() } func newConnTest(t testing.TB, dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (tt *connTest, c1, c2 *websocket.Conn) { @@ -317,30 +305,22 @@ func newConnTest(t testing.TB, dialOpts *websocket.DialOptions, acceptOpts *webs ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) tt = &connTest{t: t, ctx: ctx} - tt.appendDone(cancel) + t.Cleanup(cancel) c1, c2 = wstest.Pipe(dialOpts, acceptOpts) if xrand.Bool() { c1, c2 = c2, c1 } - tt.appendDone(func() { - c2.Close(websocket.StatusInternalError, "") - c1.Close(websocket.StatusInternalError, "") + t.Cleanup(func() { + // We don't actually care whether this succeeds so we just run it in a separate goroutine to avoid + // blocking the test shutting down. + go c2.Close(websocket.StatusInternalError, "") + go c1.Close(websocket.StatusInternalError, "") }) return tt, c1, c2 } -func (tt *connTest) appendDone(f func()) { - tt.doneFuncs = append(tt.doneFuncs, f) -} - -func (tt *connTest) cleanup() { - for i := len(tt.doneFuncs) - 1; i >= 0; i-- { - tt.doneFuncs[i]() - } -} - func (tt *connTest) goEchoLoop(c *websocket.Conn) { ctx, cancel := context.WithCancel(tt.ctx) @@ -348,7 +328,7 @@ func (tt *connTest) goEchoLoop(c *websocket.Conn) { err := wstest.EchoLoop(ctx, c) return assertCloseStatus(websocket.StatusNormalClosure, err) }) - tt.appendDone(func() { + tt.t.Cleanup(func() { cancel() err := <-echoLoopErr if err != nil { @@ -370,7 +350,7 @@ func (tt *connTest) goDiscardLoop(c *websocket.Conn) { } } }) - tt.appendDone(func() { + tt.t.Cleanup(func() { cancel() err := <-discardLoopErr if err != nil { @@ -404,7 +384,6 @@ func BenchmarkConn(b *testing.B) { }, &websocket.AcceptOptions{ CompressionMode: bc.mode, }) - defer bb.cleanup() bb.goEchoLoop(c2) From 169521697c04f5b5a06b3da51bf4cad56884d2b6 Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Mon, 18 May 2020 14:09:47 -0400 Subject: [PATCH 4/5] Add ping example Closes #227 --- autobahn_test.go | 5 +++-- example_test.go | 25 +++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/autobahn_test.go b/autobahn_test.go index 5bf0062c..7c735a38 100644 --- a/autobahn_test.go +++ b/autobahn_test.go @@ -36,11 +36,12 @@ var autobahnCases = []string{"*"} func TestAutobahn(t *testing.T) { t.Parallel() - if os.Getenv("AUTOBAHN_TEST") == "" { + if os.Getenv("AUTOBAHN") == "" { t.SkipNow() } - if os.Getenv("AUTOBAHN_FAST") != "" { + if os.Getenv("AUTOBAHN") == "fast" { + // These are the slow tests. excludedAutobahnCases = append(excludedAutobahnCases, "9.*", "13.*", "12.*", ) diff --git a/example_test.go b/example_test.go index 632c4d6e..d44bd537 100644 --- a/example_test.go +++ b/example_test.go @@ -135,6 +135,31 @@ func Example_crossOrigin() { log.Fatal(err) } +func ExampleConn_Ping() { + // Dials a server and pings it 5 times. + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil) + if err != nil { + log.Fatal(err) + } + defer c.Close(websocket.StatusInternalError, "the sky is falling") + + // Required to read the Pongs from the server. + ctx = c.CloseRead(ctx) + + for i := 0; i < 5; i++ { + err = c.Ping(ctx) + if err != nil { + log.Fatal(err) + } + } + + c.Close(websocket.StatusNormalClosure, "") +} + // This example demonstrates how to create a WebSocket server // that gracefully exits when sent a signal. // From 0a61ffe87a498f8ff9fef8020bee799cfa4f927f Mon Sep 17 00:00:00 2001 From: Anmol Sethi Date: Mon, 18 May 2020 19:09:38 -0400 Subject: [PATCH 5/5] Make SetDeadline on NetConn not always close Conn NetConn has to close the connection to interrupt in progress reads and writes. However, it can block reads and writes that occur after the deadline instead of closing the connection. Closes #228 --- conn.go | 9 ++++ netconn.go | 128 +++++++++++++++++++++++++++++++++++------------------ ws_js.go | 32 ++++++++++++++ 3 files changed, 126 insertions(+), 43 deletions(-) diff --git a/conn.go b/conn.go index e208d116..1a57c656 100644 --- a/conn.go +++ b/conn.go @@ -246,6 +246,15 @@ func (m *mu) forceLock() { m.ch <- struct{}{} } +func (m *mu) tryLock() bool { + select { + case m.ch <- struct{}{}: + return true + default: + return false + } +} + func (m *mu) lock(ctx context.Context) error { select { case <-m.c.closed: diff --git a/netconn.go b/netconn.go index 64aadf0b..ae04b20a 100644 --- a/netconn.go +++ b/netconn.go @@ -6,7 +6,7 @@ import ( "io" "math" "net" - "sync" + "sync/atomic" "time" ) @@ -28,9 +28,10 @@ import ( // // Close will close the *websocket.Conn with StatusNormalClosure. // -// When a deadline is hit, the connection will be closed. This is -// different from most net.Conn implementations where only the -// reading/writing goroutines are interrupted but the connection is kept alive. +// When a deadline is hit and there is an active read or write goroutine, the +// connection will be closed. This is different from most net.Conn implementations +// where only the reading/writing goroutines are interrupted but the connection +// is kept alive. // // The Addr methods will return a mock net.Addr that returns "websocket" for Network // and "websocket/unknown-addr" for String. @@ -41,17 +42,43 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { nc := &netConn{ c: c, msgType: msgType, + readMu: newMu(c), + writeMu: newMu(c), } - var cancel context.CancelFunc - nc.writeContext, cancel = context.WithCancel(ctx) - nc.writeTimer = time.AfterFunc(math.MaxInt64, cancel) + var writeCancel context.CancelFunc + nc.writeCtx, writeCancel = context.WithCancel(ctx) + var readCancel context.CancelFunc + nc.readCtx, readCancel = context.WithCancel(ctx) + + nc.writeTimer = time.AfterFunc(math.MaxInt64, func() { + if !nc.writeMu.tryLock() { + // If the lock cannot be acquired, then there is an + // active write goroutine and so we should cancel the context. + writeCancel() + return + } + defer nc.writeMu.unlock() + + // Prevents future writes from writing until the deadline is reset. + atomic.StoreInt64(&nc.writeExpired, 1) + }) if !nc.writeTimer.Stop() { <-nc.writeTimer.C } - nc.readContext, cancel = context.WithCancel(ctx) - nc.readTimer = time.AfterFunc(math.MaxInt64, cancel) + nc.readTimer = time.AfterFunc(math.MaxInt64, func() { + if !nc.readMu.tryLock() { + // If the lock cannot be acquired, then there is an + // active read goroutine and so we should cancel the context. + readCancel() + return + } + defer nc.readMu.unlock() + + // Prevents future reads from reading until the deadline is reset. + atomic.StoreInt64(&nc.readExpired, 1) + }) if !nc.readTimer.Stop() { <-nc.readTimer.C } @@ -64,59 +91,72 @@ type netConn struct { msgType MessageType writeTimer *time.Timer - writeContext context.Context + writeMu *mu + writeExpired int64 + writeCtx context.Context readTimer *time.Timer - readContext context.Context - - readMu sync.Mutex - eofed bool - reader io.Reader + readMu *mu + readExpired int64 + readCtx context.Context + readEOFed bool + reader io.Reader } var _ net.Conn = &netConn{} -func (c *netConn) Close() error { - return c.c.Close(StatusNormalClosure, "") +func (nc *netConn) Close() error { + return nc.c.Close(StatusNormalClosure, "") } -func (c *netConn) Write(p []byte) (int, error) { - err := c.c.Write(c.writeContext, c.msgType, p) +func (nc *netConn) Write(p []byte) (int, error) { + nc.writeMu.forceLock() + defer nc.writeMu.unlock() + + if atomic.LoadInt64(&nc.writeExpired) == 1 { + return 0, fmt.Errorf("failed to write: %w", context.DeadlineExceeded) + } + + err := nc.c.Write(nc.writeCtx, nc.msgType, p) if err != nil { return 0, err } return len(p), nil } -func (c *netConn) Read(p []byte) (int, error) { - c.readMu.Lock() - defer c.readMu.Unlock() +func (nc *netConn) Read(p []byte) (int, error) { + nc.readMu.forceLock() + defer nc.readMu.unlock() + + if atomic.LoadInt64(&nc.readExpired) == 1 { + return 0, fmt.Errorf("failed to read: %w", context.DeadlineExceeded) + } - if c.eofed { + if nc.readEOFed { return 0, io.EOF } - if c.reader == nil { - typ, r, err := c.c.Reader(c.readContext) + if nc.reader == nil { + typ, r, err := nc.c.Reader(nc.readCtx) if err != nil { switch CloseStatus(err) { case StatusNormalClosure, StatusGoingAway: - c.eofed = true + nc.readEOFed = true return 0, io.EOF } return 0, err } - if typ != c.msgType { - err := fmt.Errorf("unexpected frame type read (expected %v): %v", c.msgType, typ) - c.c.Close(StatusUnsupportedData, err.Error()) + if typ != nc.msgType { + err := fmt.Errorf("unexpected frame type read (expected %v): %v", nc.msgType, typ) + nc.c.Close(StatusUnsupportedData, err.Error()) return 0, err } - c.reader = r + nc.reader = r } - n, err := c.reader.Read(p) + n, err := nc.reader.Read(p) if err == io.EOF { - c.reader = nil + nc.reader = nil err = nil } return n, err @@ -133,34 +173,36 @@ func (a websocketAddr) String() string { return "websocket/unknown-addr" } -func (c *netConn) RemoteAddr() net.Addr { +func (nc *netConn) RemoteAddr() net.Addr { return websocketAddr{} } -func (c *netConn) LocalAddr() net.Addr { +func (nc *netConn) LocalAddr() net.Addr { return websocketAddr{} } -func (c *netConn) SetDeadline(t time.Time) error { - c.SetWriteDeadline(t) - c.SetReadDeadline(t) +func (nc *netConn) SetDeadline(t time.Time) error { + nc.SetWriteDeadline(t) + nc.SetReadDeadline(t) return nil } -func (c *netConn) SetWriteDeadline(t time.Time) error { +func (nc *netConn) SetWriteDeadline(t time.Time) error { + atomic.StoreInt64(&nc.writeExpired, 0) if t.IsZero() { - c.writeTimer.Stop() + nc.writeTimer.Stop() } else { - c.writeTimer.Reset(t.Sub(time.Now())) + nc.writeTimer.Reset(t.Sub(time.Now())) } return nil } -func (c *netConn) SetReadDeadline(t time.Time) error { +func (nc *netConn) SetReadDeadline(t time.Time) error { + atomic.StoreInt64(&nc.readExpired, 0) if t.IsZero() { - c.readTimer.Stop() + nc.readTimer.Stop() } else { - c.readTimer.Reset(t.Sub(time.Now())) + nc.readTimer.Reset(t.Sub(time.Now())) } return nil } diff --git a/ws_js.go b/ws_js.go index 31e3c2f6..d1361328 100644 --- a/ws_js.go +++ b/ws_js.go @@ -511,3 +511,35 @@ const ( // MessageBinary is for binary messages like protobufs. MessageBinary ) + +type mu struct { + c *Conn + ch chan struct{} +} + +func newMu(c *Conn) *mu { + return &mu{ + c: c, + ch: make(chan struct{}, 1), + } +} + +func (m *mu) forceLock() { + m.ch <- struct{}{} +} + +func (m *mu) tryLock() bool { + select { + case m.ch <- struct{}{}: + return true + default: + return false + } +} + +func (m *mu) unlock() { + select { + case <-m.ch: + default: + } +}