diff --git a/README.md b/README.md index f1fc5896..4199423c 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ websocket is a minimal and idiomatic WebSocket library for Go. ## Install ```bash -go get nhooyr.io/websocket +go get nhooyr.io/websocket@v1.0.0 ``` ## Features @@ -19,7 +19,7 @@ go get nhooyr.io/websocket - Thorough tests, fully passes the [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) - Zero dependencies outside of the stdlib for the core library - JSON and ProtoBuf helpers in the wsjson and wspb subpackages -- High performance, memory reuse by default +- Highly optimized by default - Concurrent writes out of the box ## Roadmap @@ -129,8 +129,9 @@ gorilla/websocket requires you to constantly read from the connection to respond even if you don't expect the peer to send any messages. In terms of performance, the differences depend on your application code. nhooyr/websocket -reuses buffers efficiently out of the box whereas gorilla/websocket does not. As mentioned -above, it also supports concurrent writers out of the box. +reuses buffers efficiently out of the box if you use the wsjson and wspb subpackages whereas +gorilla/websocket does not. As mentioned above, nhooyr/websocket also supports concurrent +writers out of the box. The only performance con to nhooyr/websocket is that uses two extra goroutines. One for reading pings, pongs and close frames async to application code and another to support diff --git a/websocket.go b/websocket.go index e974002b..d59812b8 100644 --- a/websocket.go +++ b/websocket.go @@ -63,65 +63,6 @@ type Conn struct { activePings map[string]chan<- struct{} } -// Context returns a context derived from parent that will be cancelled -// when the connection is closed or broken. -// If the parent context is cancelled, the connection will be closed. -// -// This is an experimental API. -// Please let me know how you feel about it in https://github.com/nhooyr/websocket/issues/79 -func (c *Conn) Context(parent context.Context) context.Context { - select { - case <-c.closed: - ctx, cancel := context.WithCancel(parent) - cancel() - return ctx - case c.setConnContext <- parent: - } - - select { - case <-c.closed: - ctx, cancel := context.WithCancel(parent) - cancel() - return ctx - case ctx := <-c.getConnContext: - return ctx - } -} - -func (c *Conn) close(err error) { - c.closeOnce.Do(func() { - runtime.SetFinalizer(c, nil) - - c.closeErr = xerrors.Errorf("websocket closed: %w", err) - close(c.closed) - - // Have to close after c.closed is closed to ensure any goroutine that wakes up - // from the connection being closed also sees that c.closed is closed and returns - // closeErr. - c.closer.Close() - - // See comment in dial.go - if c.client { - // By acquiring the locks, we ensure no goroutine will touch the bufio reader or writer - // and we can safely return them. - // Whenever a caller holds this lock and calls close, it ensures to release the lock to prevent - // a deadlock. - // As of now, this is in writeFrame, readPayload and readHeader. - c.readFrameLock <- struct{}{} - returnBufioReader(c.br) - - c.writeFrameLock <- struct{}{} - returnBufioWriter(c.bw) - } - }) -} - -// Subprotocol returns the negotiated subprotocol. -// An empty string means the default protocol. -func (c *Conn) Subprotocol() string { - return c.subprotocol -} - func (c *Conn) init() { c.closed = make(chan struct{}) @@ -149,79 +90,38 @@ func (c *Conn) init() { go c.readLoop() } -func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { - err := c.writeFrame(ctx, true, opcode, p) - if err != nil { - return xerrors.Errorf("failed to write control frame: %w", err) - } - return nil +// Subprotocol returns the negotiated subprotocol. +// An empty string means the default protocol. +func (c *Conn) Subprotocol() string { + return c.subprotocol } -// writeFrame handles all writes to the connection. -// We never mask inside here because our mask key is always 0,0,0,0. -// See comment on secWebSocketKey for why. -func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte) error { - h := header{ - fin: fin, - opcode: opcode, - masked: c.client, - payloadLength: int64(len(p)), - } - b2 := marshalHeader(h) - - err := c.acquireLock(ctx, c.writeFrameLock) - if err != nil { - return err - } - defer c.releaseLock(c.writeFrameLock) - - select { - case <-c.closed: - return c.closeErr - case c.setWriteTimeout <- ctx: - } - - writeErr := func(err error) error { - select { - case <-c.closed: - return c.closeErr - default: - } +func (c *Conn) close(err error) { + c.closeOnce.Do(func() { + runtime.SetFinalizer(c, nil) - err = xerrors.Errorf("failed to write to connection: %w", err) - // We need to release the lock first before closing the connection to ensure - // the lock can be acquired inside close to ensure no one can access c.bw. - c.releaseLock(c.writeFrameLock) - c.close(err) + c.closeErr = xerrors.Errorf("websocket closed: %w", err) + close(c.closed) - return err - } + // Have to close after c.closed is closed to ensure any goroutine that wakes up + // from the connection being closed also sees that c.closed is closed and returns + // closeErr. + c.closer.Close() - _, err = c.bw.Write(b2) - if err != nil { - return writeErr(err) - } - _, err = c.bw.Write(p) - if err != nil { - return writeErr(err) - } + // See comment in dial.go + if c.client { + // By acquiring the locks, we ensure no goroutine will touch the bufio reader or writer + // and we can safely return them. + // Whenever a caller holds this lock and calls close, it ensures to release the lock to prevent + // a deadlock. + // As of now, this is in writeFrame, readFramePayload and readHeader. + c.readFrameLock <- struct{}{} + returnBufioReader(c.br) - if fin { - err = c.bw.Flush() - if err != nil { - return writeErr(err) + c.writeFrameLock <- struct{}{} + returnBufioWriter(c.bw) } - } - - // We already finished writing, no need to potentially brick the connection if - // the context expires. - select { - case <-c.closed: - return c.closeErr - case c.setWriteTimeout <- context.Background(): - } - - return nil + }) } func (c *Conn) timeoutLoop() { @@ -255,60 +155,84 @@ func (c *Conn) timeoutLoop() { } } -func (c *Conn) handleControl(h header) { - if h.payloadLength > maxControlFramePayload { - c.Close(StatusProtocolError, "control frame too large") - return +// Context returns a context derived from parent that will be cancelled +// when the connection is closed or broken. +// If the parent context is cancelled, the connection will be closed. +// +// This is an experimental API. +// Please let me know how you feel about it in https://github.com/nhooyr/websocket/issues/79 +func (c *Conn) Context(parent context.Context) context.Context { + select { + case <-c.closed: + ctx, cancel := context.WithCancel(parent) + cancel() + return ctx + case c.setConnContext <- parent: } - if !h.fin { - c.Close(StatusProtocolError, "control frame cannot be fragmented") - return + select { + case <-c.closed: + ctx, cancel := context.WithCancel(parent) + cancel() + return ctx + case ctx := <-c.getConnContext: + return ctx } +} - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - b := make([]byte, h.payloadLength) - - _, err := c.readFramePayload(ctx, b) - if err != nil { - return +func (c *Conn) acquireLock(ctx context.Context, lock chan struct{}) error { + select { + case <-ctx.Done(): + var err error + switch lock { + case c.writeFrameLock, c.writeMsgLock: + err = xerrors.Errorf("could not acquire write lock: %v", ctx.Err()) + case c.readFrameLock: + err = xerrors.Errorf("could not acquire read lock: %v", ctx.Err()) + default: + panic(fmt.Sprintf("websocket: failed to acquire unknown lock: %v", ctx.Err())) + } + c.close(err) + return ctx.Err() + case <-c.closed: + return c.closeErr + case lock <- struct{}{}: + return nil } +} - if h.masked { - fastXOR(h.maskKey, 0, b) +func (c *Conn) releaseLock(lock chan struct{}) { + // Allow multiple releases. + select { + case <-lock: + default: } +} - switch h.opcode { - case opPing: - c.writePong(b) - case opPong: - c.activePingsMu.Lock() - pong, ok := c.activePings[string(b)] - c.activePingsMu.Unlock() - if ok { - close(pong) - } - case opClose: - ce, err := parseClosePayload(b) +func (c *Conn) readLoop() { + for { + h, err := c.readTillMsg() if err != nil { - c.close(xerrors.Errorf("received invalid close payload: %w", err)) return } - if ce.Code == StatusNoStatusRcvd { - c.writeClose(nil, ce) - } else { - c.Close(ce.Code, ce.Reason) + + select { + case <-c.closed: + return + case c.readMsg <- h: + } + + select { + case <-c.closed: + return + case <-c.readMsgDone: } - default: - panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h)) } } func (c *Conn) readTillMsg() (header, error) { for { - h, err := c.readHeader() + h, err := c.readFrameHeader() if err != nil { return header{}, err } @@ -335,7 +259,7 @@ func (c *Conn) readTillMsg() (header, error) { } } -func (c *Conn) readHeader() (header, error) { +func (c *Conn) readFrameHeader() (header, error) { err := c.acquireLock(context.Background(), c.readFrameLock) if err != nil { return header{}, err @@ -353,119 +277,282 @@ func (c *Conn) readHeader() (header, error) { return h, nil } -func (c *Conn) readLoop() { - for { - h, err := c.readTillMsg() - if err != nil { - return - } +func (c *Conn) handleControl(h header) { + if h.payloadLength > maxControlFramePayload { + c.Close(StatusProtocolError, "control frame too large") + return + } - select { - case <-c.closed: - return - case c.readMsg <- h: - } - - select { - case <-c.closed: - return - case <-c.readMsgDone: - } + if !h.fin { + c.Close(StatusProtocolError, "control frame cannot be fragmented") + return } -} -func (c *Conn) writePong(p []byte) error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - err := c.writeControl(ctx, opPong, p) - return err + b := make([]byte, h.payloadLength) + + _, err := c.readFramePayload(ctx, b) + if err != nil { + return + } + + if h.masked { + fastXOR(h.maskKey, 0, b) + } + + switch h.opcode { + case opPing: + c.writePong(b) + case opPong: + c.activePingsMu.Lock() + pong, ok := c.activePings[string(b)] + c.activePingsMu.Unlock() + if ok { + close(pong) + } + case opClose: + ce, err := parseClosePayload(b) + if err != nil { + c.close(xerrors.Errorf("received invalid close payload: %w", err)) + return + } + if ce.Code == StatusNoStatusRcvd { + c.writeClose(nil, ce) + } else { + c.Close(ce.Code, ce.Reason) + } + default: + panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h)) + } } -// Close closes the WebSocket connection with the given status code and reason. -// -// It will write a WebSocket close frame with a timeout of 5 seconds. -// The connection can only be closed once. Additional calls to Close -// are no-ops. +// Reader waits until there is a WebSocket data message to read +// from the connection. +// It returns the type of the message and a reader to read it. +// The passed context will also bound the reader. +// Ensure you read to EOF otherwise the connection will hang. // -// The maximum length of reason must be 125 bytes otherwise an internal -// error will be sent to the peer. For this reason, you should avoid -// sending a dynamic reason. +// Control (ping, pong, close) frames will be handled automatically +// in a separate goroutine so if you do not expect any data messages, +// you do not need to read from the connection. However, if the peer +// sends a data message, further pings, pongs and close frames will not +// be read if you do not read the message from the connection. // -// Close will unblock all goroutines interacting with the connection. -func (c *Conn) Close(code StatusCode, reason string) error { - err := c.exportedClose(code, reason) +// Only one Reader may be open at a time. +func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { + typ, r, err := c.reader(ctx) if err != nil { - return xerrors.Errorf("failed to close connection: %w", err) + return 0, nil, xerrors.Errorf("failed to get reader: %w", err) } - return nil + return typ, &limitedReader{ + c: c, + r: r, + left: c.msgReadLimit, + }, nil } -func (c *Conn) exportedClose(code StatusCode, reason string) error { - ce := CloseError{ - Code: code, - Reason: reason, +func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { + if c.previousReader != nil && c.previousReader.h != nil { + // The only way we know for sure the previous reader is not yet complete is + // if there is an active frame not yet fully read. + // Otherwise, a user may have read the last byte but not the EOF if the EOF + // is in the next frame so we check for that below. + return 0, nil, xerrors.Errorf("previous message not read to completion") } - // This function also will not wait for a close frame from the peer like the RFC - // wants because that makes no sense and I don't think anyone actually follows that. - // Definitely worth seeing what popular browsers do later. - p, err := ce.bytes() - if err != nil { - fmt.Fprintf(os.Stderr, "websocket: failed to marshal close frame: %v\n", err) - ce = CloseError{ - Code: StatusInternalError, + select { + case <-c.closed: + return 0, nil, c.closeErr + case <-ctx.Done(): + return 0, nil, ctx.Err() + case h := <-c.readMsg: + if c.previousReader != nil && !c.previousReader.done { + if h.opcode != opContinuation { + err := xerrors.Errorf("received new data message without finishing the previous message") + c.Close(StatusProtocolError, err.Error()) + return 0, nil, err + } + + if !h.fin || h.payloadLength > 0 { + return 0, nil, xerrors.Errorf("previous message not read to completion") + } + + c.previousReader.done = true + + select { + case <-c.closed: + return 0, nil, c.closeErr + case c.readMsgDone <- struct{}{}: + } + + return c.reader(ctx) + } else if h.opcode == opContinuation { + err := xerrors.Errorf("received continuation frame not after data or text frame") + c.Close(StatusProtocolError, err.Error()) + return 0, nil, err } - p, _ = ce.bytes() + + r := &messageReader{ + ctx: ctx, + c: c, + + h: &h, + } + c.previousReader = r + return MessageType(h.opcode), r, nil } +} - return c.writeClose(p, ce) +// messageReader enables reading a data frame from the WebSocket connection. +type messageReader struct { + ctx context.Context + c *Conn + + h *header + maskPos int + done bool } -func (c *Conn) writeClose(p []byte, cerr CloseError) error { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() +// Read reads as many bytes as possible into p. +func (r *messageReader) Read(p []byte) (int, error) { + n, err := r.read(p) + if err != nil { + // Have to return io.EOF directly for now, we cannot wrap as xerrors + // isn't used in stdlib. + if xerrors.Is(err, io.EOF) { + return n, io.EOF + } + return n, xerrors.Errorf("failed to read: %w", err) + } + return n, nil +} + +func (r *messageReader) read(p []byte) (int, error) { + if r.done { + return 0, xerrors.Errorf("cannot use EOFed reader") + } + + if r.h == nil { + select { + case <-r.c.closed: + return 0, r.c.closeErr + case <-r.ctx.Done(): + r.c.close(xerrors.Errorf("failed to read: %w", r.ctx.Err())) + return 0, r.ctx.Err() + case h := <-r.c.readMsg: + if h.opcode != opContinuation { + err := xerrors.Errorf("received new data frame without finishing the previous frame") + r.c.Close(StatusProtocolError, err.Error()) + return 0, err + } + r.h = &h + } + } + + if int64(len(p)) > r.h.payloadLength { + p = p[:r.h.payloadLength] + } + + n, err := r.c.readFramePayload(r.ctx, p) + + r.h.payloadLength -= int64(n) + if r.h.masked { + r.maskPos = fastXOR(r.h.maskKey, r.maskPos, p) + } - err := c.writeControl(ctx, opClose, p) if err != nil { - return err + return n, err } - c.close(cerr) - if !xerrors.Is(c.closeErr, cerr) { - return c.closeErr + if r.h.payloadLength == 0 { + select { + case <-r.c.closed: + return n, r.c.closeErr + case r.c.readMsgDone <- struct{}{}: + } + + fin := r.h.fin + + // Need to nil this as Reader uses it to check + // whether there is active data on the previous reader and + // now there isn't. + r.h = nil + + if fin { + r.done = true + return n, io.EOF + } + + r.maskPos = 0 } - return nil + return n, nil } -func (c *Conn) acquireLock(ctx context.Context, lock chan struct{}) error { +func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { + err := c.acquireLock(ctx, c.readFrameLock) + if err != nil { + return 0, err + } + defer c.releaseLock(c.readFrameLock) + select { - case <-ctx.Done(): - var err error - switch lock { - case c.writeFrameLock, c.writeMsgLock: - err = xerrors.Errorf("could not acquire write lock: %v", ctx.Err()) - case c.readFrameLock: - err = xerrors.Errorf("could not acquire read lock: %v", ctx.Err()) + case <-c.closed: + return 0, c.closeErr + case c.setReadTimeout <- ctx: + } + + n, err := io.ReadFull(c.br, p) + if err != nil { + select { + case <-c.closed: + return n, c.closeErr + case <-ctx.Done(): + err = ctx.Err() default: - panic(fmt.Sprintf("websocket: failed to acquire unknown lock: %v", ctx.Err())) } + err = xerrors.Errorf("failed to read from connection: %w", err) + c.releaseLock(c.readFrameLock) c.close(err) - return ctx.Err() + return n, err + } + + select { case <-c.closed: - return c.closeErr - case lock <- struct{}{}: - return nil + return n, c.closeErr + case c.setReadTimeout <- context.Background(): } + + return n, err } -func (c *Conn) releaseLock(lock chan struct{}) { - // Allow multiple releases. - select { - case <-lock: - default: +// SetReadLimit sets the max number of bytes to read for a single message. +// It applies to the Reader and Read methods. +// +// By default, the connection has a message read limit of 32768 bytes. +// +// When the limit is hit, the connection will be closed with StatusPolicyViolation. +func (c *Conn) SetReadLimit(n int64) { + c.msgReadLimit = n +} + +// Read is a convenience method to read a single message from the connection. +// +// See the Reader method if you want to be able to reuse buffers or want to stream a message. +// The docs on Reader apply to this method as well. +// +// This is an experimental API, please let me know how you feel about it in +// https://github.com/nhooyr/websocket/issues/62 +func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { + typ, r, err := c.Reader(ctx) + if err != nil { + return 0, nil, err } + + b, err := ioutil.ReadAll(r) + return typ, b, err } // Writer returns a writer bounded by the context that will write @@ -488,28 +575,11 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err if err != nil { return nil, err } - return &messageWriter{ - ctx: ctx, - opcode: opcode(typ), - c: c, - }, nil -} - -// Read is a convenience method to read a single message from the connection. -// -// See the Reader method if you want to be able to reuse buffers or want to stream a message. -// The docs on Reader apply to this method as well. -// -// This is an experimental API, please let me know how you feel about it in -// https://github.com/nhooyr/websocket/issues/62 -func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { - typ, r, err := c.Reader(ctx) - if err != nil { - return 0, nil, err - } - - b, err := ioutil.ReadAll(r) - return typ, b, err + return &messageWriter{ + ctx: ctx, + opcode: opcode(typ), + c: c, + }, nil } // Write is a convenience method to write a message to the connection. @@ -592,194 +662,146 @@ func (w *messageWriter) close() error { return nil } -// Reader waits until there is a WebSocket data message to read -// from the connection. -// It returns the type of the message and a reader to read it. -// The passed context will also bound the reader. -// Ensure you read to EOF otherwise the connection will hang. -// -// Control (ping, pong, close) frames will be handled automatically -// in a separate goroutine so if you do not expect any data messages, -// you do not need to read from the connection. However, if the peer -// sends a data message, further pings, pongs and close frames will not -// be read if you do not read the message from the connection. -// -// Only one Reader may be open at a time. -func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { - typ, r, err := c.reader(ctx) +func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { + err := c.writeFrame(ctx, true, opcode, p) if err != nil { - return 0, nil, xerrors.Errorf("failed to get reader: %w", err) + return xerrors.Errorf("failed to write control frame: %w", err) } - return typ, &limitedReader{ - c: c, - r: r, - left: c.msgReadLimit, - }, nil + return nil } -func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { - if c.previousReader.h != nil && c.previousReader.h.payloadLength > 0 { - return 0, nil, xerrors.Errorf("previous message not read to completion") +// writeFrame handles all writes to the connection. +// We never mask inside here because our mask key is always 0,0,0,0. +// See comment on secWebSocketKey for why. +func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte) error { + h := header{ + fin: fin, + opcode: opcode, + masked: c.client, + payloadLength: int64(len(p)), + } + b2 := marshalHeader(h) + + err := c.acquireLock(ctx, c.writeFrameLock) + if err != nil { + return err } + defer c.releaseLock(c.writeFrameLock) select { case <-c.closed: - return 0, nil, c.closeErr - case <-ctx.Done(): - return 0, nil, ctx.Err() - case h := <-c.readMsg: - if c.previousReader != nil && !c.previousReader.done { - if h.opcode != opContinuation { - err := xerrors.Errorf("received new data message without finishing the previous message") - c.Close(StatusProtocolError, err.Error()) - return 0, nil, err - } - - if !h.fin || h.payloadLength > 0 { - return 0, nil, xerrors.Errorf("previous message not read to completion") - } + return c.closeErr + case c.setWriteTimeout <- ctx: + } - c.previousReader.done = true - return c.reader(ctx) - } else if h.opcode == opContinuation { - err := xerrors.Errorf("received continuation frame not after data or text frame") - c.Close(StatusProtocolError, err.Error()) - return 0, nil, err - } - r := &messageReader{ - ctx: ctx, - h: &h, - c: c, + writeErr := func(err error) error { + select { + case <-c.closed: + return c.closeErr + case <-ctx.Done(): + err = ctx.Err() + default: } - c.previousReader = r - return MessageType(h.opcode), r, nil - } -} -// messageReader enables reading a data frame from the WebSocket connection. -type messageReader struct { - ctx context.Context - c *Conn + err = xerrors.Errorf("failed to write to connection: %w", err) + // We need to release the lock first before closing the connection to ensure + // the lock can be acquired inside close to ensure no one can access c.bw. + c.releaseLock(c.writeFrameLock) + c.close(err) - h *header - maskPos int - done bool -} + return err + } -// Read reads as many bytes as possible into p. -func (r *messageReader) Read(p []byte) (int, error) { - n, err := r.read(p) + _, err = c.bw.Write(b2) if err != nil { - // Have to return io.EOF directly for now, we cannot wrap as xerrors - // isn't used in stdlib. - if xerrors.Is(err, io.EOF) { - return n, io.EOF - } - err = xerrors.Errorf("failed to read: %w", err) - r.c.close(err) - return n, err + return writeErr(err) } - return n, nil -} - -func (r *messageReader) read(p []byte) (int, error) { - if r.done { - return 0, xerrors.Errorf("cannot use EOFed reader") + _, err = c.bw.Write(p) + if err != nil { + return writeErr(err) } - if r.h == nil { - select { - case <-r.c.closed: - return 0, r.c.closeErr - case <-r.ctx.Done(): - return 0, r.ctx.Err() - case h := <-r.c.readMsg: - if h.opcode != opContinuation { - err := xerrors.Errorf("received new data frame without finishing the previous frame") - r.c.Close(StatusProtocolError, err.Error()) - return 0, err - } - r.h = &h + if fin { + err = c.bw.Flush() + if err != nil { + return writeErr(err) } } - if int64(len(p)) > r.h.payloadLength { - p = p[:r.h.payloadLength] + // We already finished writing, no need to potentially brick the connection if + // the context expires. + select { + case <-c.closed: + return c.closeErr + case c.setWriteTimeout <- context.Background(): } - n, err := r.c.readFramePayload(r.ctx, p) + return nil +} - r.h.payloadLength -= int64(n) - if r.h.masked { - r.maskPos = fastXOR(r.h.maskKey, r.maskPos, p) - } +func (c *Conn) writePong(p []byte) error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + err := c.writeControl(ctx, opPong, p) + return err +} +// Close closes the WebSocket connection with the given status code and reason. +// +// It will write a WebSocket close frame with a timeout of 5 seconds. +// The connection can only be closed once. Additional calls to Close +// are no-ops. +// +// The maximum length of reason must be 125 bytes otherwise an internal +// error will be sent to the peer. For this reason, you should avoid +// sending a dynamic reason. +// +// Close will unblock all goroutines interacting with the connection. +func (c *Conn) Close(code StatusCode, reason string) error { + err := c.exportedClose(code, reason) if err != nil { - return n, err + return xerrors.Errorf("failed to close connection: %w", err) } + return nil +} - if r.h.payloadLength == 0 { - select { - case <-r.c.closed: - return n, r.c.closeErr - case r.c.readMsgDone <- struct{}{}: - } +func (c *Conn) exportedClose(code StatusCode, reason string) error { + ce := CloseError{ + Code: code, + Reason: reason, + } - if r.h.fin { - r.done = true - return n, io.EOF + // This function also will not wait for a close frame from the peer like the RFC + // wants because that makes no sense and I don't think anyone actually follows that. + // Definitely worth seeing what popular browsers do later. + p, err := ce.bytes() + if err != nil { + fmt.Fprintf(os.Stderr, "websocket: failed to marshal close frame: %v\n", err) + ce = CloseError{ + Code: StatusInternalError, } - - r.maskPos = 0 - r.h = nil + p, _ = ce.bytes() } - return n, nil + return c.writeClose(p, ce) } -func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { - err := c.acquireLock(ctx, c.readFrameLock) - if err != nil { - return 0, err - } - defer c.releaseLock(c.readFrameLock) - - select { - case <-c.closed: - return 0, c.closeErr - case c.setReadTimeout <- ctx: - } +func (c *Conn) writeClose(p []byte, cerr CloseError) error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() - n, err := io.ReadFull(c.br, p) + err := c.writeControl(ctx, opClose, p) if err != nil { - select { - case <-c.closed: - return n, c.closeErr - default: - } - err = xerrors.Errorf("failed to read from connection: %w", err) - c.releaseLock(c.readFrameLock) - c.close(err) - return n, err + return err } - select { - case <-c.closed: - return 0, c.closeErr - case c.setReadTimeout <- context.Background(): + c.close(cerr) + if !xerrors.Is(c.closeErr, cerr) { + return c.closeErr } - return 0, err -} - -// SetReadLimit sets the max number of bytes to read for a single message. -// It applies to the Reader and Read methods. -// -// By default, the connection has a message read limit of 32768 bytes. -// -// When the limit is hit, the connection will be closed with StatusPolicyViolation. -func (c *Conn) SetReadLimit(n int64) { - c.msgReadLimit = n + return nil } func init() { @@ -794,9 +816,7 @@ func init() { func (c *Conn) Ping(ctx context.Context) error { err := c.ping(ctx) if err != nil { - err = xerrors.Errorf("failed to ping: %w", err) - c.close(err) - return err + return xerrors.Errorf("failed to ping: %w", err) } return nil } @@ -823,10 +843,11 @@ func (c *Conn) ping(ctx context.Context) error { } select { - case <-ctx.Done(): - return ctx.Err() case <-c.closed: return c.closeErr + case <-ctx.Done(): + c.close(xerrors.Errorf("failed to ping: %w", ctx.Err())) + return ctx.Err() case <-pong: return nil }