Skip to content

Commit

Permalink
Merge pull request #99 from nhooyr/closeread
Browse files Browse the repository at this point in the history
Add CloseRead and closeError test
  • Loading branch information
nhooyr authored Jun 23, 2019
2 parents 3149225 + 6eda9c5 commit 176b144
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 10 deletions.
6 changes: 1 addition & 5 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,7 @@ func Example_writeOnly() {
ctx, cancel := context.WithTimeout(r.Context(), time.Minute*10)
defer cancel()

go func() {
defer cancel()
c.Reader(ctx)
c.Close(websocket.StatusPolicyViolation, "server doesn't accept data messages")
}()
ctx = c.CloseRead(ctx)

t := time.NewTicker(time.Second * 30)
defer t.Stop()
Expand Down
23 changes: 18 additions & 5 deletions websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (
// and SetReadLimit.
//
// You must always read from the connection. Otherwise control
// frames will not be handled. See the docs on Reader.
// frames will not be handled. See the docs on Reader and CloseRead.
//
// Please be sure to call Close on the connection when you
// are finished with it to release the associated resources.
Expand Down Expand Up @@ -319,10 +319,8 @@ func (c *Conn) handleControl(ctx context.Context, h header) error {
// to be closed so you do not need to write your own error message.
// This applies to the Read methods in the wsjson/wspb subpackages as well.
//
// You must read from the connection for close frames to be read.
// If you do not expect any data messages from the peer, just call
// Reader in a separate goroutine and close the connection with StatusPolicyViolation
// when it returns. See the writeOnly example.
// You must read from the connection for control frames to be handled.
// If you do not expect any data messages from the peer, call CloseRead.
//
// Only one Reader may be open at a time.
//
Expand Down Expand Up @@ -388,6 +386,21 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
return MessageType(h.opcode), r, nil
}

// CloseRead will close the connection if any data message is received from the peer.
// Call this when you are done reading data messages from the connection but will still write
// to it. Since CloseRead is still reading from the connection, it will respond to ping, pong
// and close frames automatically. It will only close the connection on a data frame. The returned
// context will be cancelled when the connection is closed.
func (c *Conn) CloseRead(ctx context.Context) context.Context {
ctx, cancel := context.WithCancel(ctx)
go func() {
defer cancel()
c.Reader(ctx)
c.Close(StatusPolicyViolation, "unexpected data message")
}()
return ctx
}

// messageReader enables reading a data frame from the WebSocket connection.
type messageReader struct {
c *Conn
Expand Down
44 changes: 44 additions & 0 deletions websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,50 @@ func TestHandshake(t *testing.T) {
return nil
},
},
{
name: "closeError",
server: func(w http.ResponseWriter, r *http.Request) error {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")

err = wsjson.Write(r.Context(), c, "hello")
if err != nil {
return err
}

return nil
},
client: func(ctx context.Context, u string) error {
c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
Subprotocols: []string{"meow"},
})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")

var m string
err = wsjson.Read(ctx, c, &m)
if err != nil {
return err
}

if m != "hello" {
return xerrors.Errorf("recieved unexpected msg but expected hello: %+v", m)
}

_, _, err = c.Reader(ctx)
var cerr websocket.CloseError
if !xerrors.As(err, &cerr) || cerr.Code != websocket.StatusInternalError {
return xerrors.Errorf("unexpected error: %+v", err)
}

return nil
},
},
{
name: "defaultSubprotocol",
server: func(w http.ResponseWriter, r *http.Request) error {
Expand Down

0 comments on commit 176b144

Please sign in to comment.