diff --git a/example_test.go b/example_test.go index 050af907..0b59e6a0 100644 --- a/example_test.go +++ b/example_test.go @@ -74,11 +74,7 @@ func Example_writeOnly() { ctx, cancel := context.WithTimeout(r.Context(), time.Minute*10) defer cancel() - go func() { - defer cancel() - c.Reader(ctx) - c.Close(websocket.StatusPolicyViolation, "server doesn't accept data messages") - }() + ctx = c.CloseRead(ctx) t := time.NewTicker(time.Second * 30) defer t.Stop() diff --git a/websocket.go b/websocket.go index bc90415d..91a6808f 100644 --- a/websocket.go +++ b/websocket.go @@ -22,7 +22,7 @@ import ( // and SetReadLimit. // // You must always read from the connection. Otherwise control -// frames will not be handled. See the docs on Reader. +// frames will not be handled. See the docs on Reader and CloseRead. // // Please be sure to call Close on the connection when you // are finished with it to release the associated resources. @@ -319,10 +319,8 @@ func (c *Conn) handleControl(ctx context.Context, h header) error { // to be closed so you do not need to write your own error message. // This applies to the Read methods in the wsjson/wspb subpackages as well. // -// You must read from the connection for close frames to be read. -// If you do not expect any data messages from the peer, just call -// Reader in a separate goroutine and close the connection with StatusPolicyViolation -// when it returns. See the writeOnly example. +// You must read from the connection for control frames to be handled. +// If you do not expect any data messages from the peer, call CloseRead. // // Only one Reader may be open at a time. // @@ -388,6 +386,21 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { return MessageType(h.opcode), r, nil } +// CloseRead will close the connection if any data message is received from the peer. +// Call this when you are done reading data messages from the connection but will still write +// to it. Since CloseRead is still reading from the connection, it will respond to ping, pong +// and close frames automatically. It will only close the connection on a data frame. The returned +// context will be cancelled when the connection is closed. +func (c *Conn) CloseRead(ctx context.Context) context.Context { + ctx, cancel := context.WithCancel(ctx) + go func() { + defer cancel() + c.Reader(ctx) + c.Close(StatusPolicyViolation, "unexpected data message") + }() + return ctx +} + // messageReader enables reading a data frame from the WebSocket connection. type messageReader struct { c *Conn diff --git a/websocket_test.go b/websocket_test.go index 993ff9ab..2d7db271 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -74,6 +74,50 @@ func TestHandshake(t *testing.T) { return nil }, }, + { + name: "closeError", + server: func(w http.ResponseWriter, r *http.Request) error { + c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + if err != nil { + return err + } + defer c.Close(websocket.StatusInternalError, "") + + err = wsjson.Write(r.Context(), c, "hello") + if err != nil { + return err + } + + return nil + }, + client: func(ctx context.Context, u string) error { + c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ + Subprotocols: []string{"meow"}, + }) + if err != nil { + return err + } + defer c.Close(websocket.StatusInternalError, "") + + var m string + err = wsjson.Read(ctx, c, &m) + if err != nil { + return err + } + + if m != "hello" { + return xerrors.Errorf("recieved unexpected msg but expected hello: %+v", m) + } + + _, _, err = c.Reader(ctx) + var cerr websocket.CloseError + if !xerrors.As(err, &cerr) || cerr.Code != websocket.StatusInternalError { + return xerrors.Errorf("unexpected error: %+v", err) + } + + return nil + }, + }, { name: "defaultSubprotocol", server: func(w http.ResponseWriter, r *http.Request) error {