From 471b5cc7e26c6f3e8dff6d089ba61a2f8365a499 Mon Sep 17 00:00:00 2001 From: Jannis Mattheis Date: Fri, 27 Sep 2024 12:51:10 +0200 Subject: [PATCH] fix: goroutine leak --- ws/client.go | 76 +++++++++++++++++++--------------------- ws/event_connected.go | 8 +++++ ws/event_create.go | 1 - ws/event_disconnected.go | 30 +++++++++++----- ws/event_join.go | 1 - ws/outgoing/messages.go | 9 +++++ ws/room.go | 1 - ws/rooms.go | 15 ++++++-- 8 files changed, 87 insertions(+), 54 deletions(-) create mode 100644 ws/event_connected.go diff --git a/ws/client.go b/ws/client.go index 505ed64a..498a7882 100644 --- a/ws/client.go +++ b/ws/client.go @@ -34,8 +34,9 @@ type Client struct { } type ClientMessage struct { - Info ClientInfo - Incoming Event + Info ClientInfo + SkipConnectedCheck bool + Incoming Event } type ClientInfo struct { @@ -44,7 +45,6 @@ type ClientInfo struct { Authenticated bool AuthenticatedUser string Write chan outgoing.Message - Close chan string Addr net.IP } @@ -63,36 +63,46 @@ func newClient(conn *websocket.Conn, req *http.Request, read chan ClientMessage, RoomID: "", Addr: ip, Write: make(chan outgoing.Message, 1), - Close: make(chan string, 1), }, read: read, } client.debug().Msg("WebSocket New Connection") - conn.SetCloseHandler(func(code int, text string) error { - message := websocket.FormatCloseMessage(code, text) - client.debug().Str("reason", text).Int("code", code).Msg("WebSocket Close") - return conn.WriteControl(websocket.CloseMessage, message, time.Now().Add(writeWait)) - }) return client } -// Close closes the connection. -func (c *Client) Close() { +// CloseOnError closes the connection. +func (c *Client) CloseOnError(code int, reason string) { c.once.Do(func() { - c.conn.Close() go func() { c.read <- ClientMessage{ - Info: c.info, - Incoming: &Disconnected{}, + Info: c.info, + Incoming: &Disconnected{ + Code: code, + Reason: reason, + }, } }() + c.writeCloseMessage(code, reason) }) } +func (c *Client) CloseOnDone(code int, reason string) { + c.once.Do(func() { + c.writeCloseMessage(code, reason) + }) +} + +func (c *Client) writeCloseMessage(code int, reason string) { + message := websocket.FormatCloseMessage(code, reason) + c.conn.WriteControl(websocket.CloseMessage, message, time.Now().Add(writeWait)) + c.conn.Close() +} + // startWriteHandler starts listening on the client connection. As we do not need anything from the client, // we ignore incoming messages. Leaves the loop on errors. func (c *Client) startReading(pongWait time.Duration) { - defer c.Close() + defer c.CloseOnError(websocket.CloseNormalClosure, "Reader Routine Closed") + _ = c.conn.SetReadDeadline(time.Now().Add(pongWait)) c.conn.SetPongHandler(func(appData string) error { _ = c.conn.SetReadDeadline(time.Now().Add(pongWait)) @@ -101,17 +111,17 @@ func (c *Client) startReading(pongWait time.Duration) { for { t, m, err := c.conn.NextReader() if err != nil { - c.printWebSocketError("read", err) + c.CloseOnError(websocket.CloseNormalClosure, "read error") return } if t == websocket.BinaryMessage { - _ = c.conn.CloseHandler()(websocket.CloseUnsupportedData, fmt.Sprintf("unsupported binary message type: %s", err)) + c.CloseOnError(websocket.CloseUnsupportedData, "unsupported binary message type") return } incoming, err := ReadTypedIncoming(m) if err != nil { - _ = c.conn.CloseHandler()(websocket.CloseNormalClosure, fmt.Sprintf("malformed message: %s", err)) + c.CloseOnError(websocket.CloseUnsupportedData, fmt.Sprintf("malformed message: %s", err)) return } c.debug().Interface("event", fmt.Sprintf("%T", incoming)).Interface("payload", incoming).Msg("WebSocket Receive") @@ -125,30 +135,18 @@ func (c *Client) startReading(pongWait time.Duration) { // * on errors exit the loop. func (c *Client) startWriteHandler(pingPeriod time.Duration) { pingTicker := time.NewTicker(pingPeriod) - - dead := false - conClosed := func() { - dead = true - c.Close() - pingTicker.Stop() - } - defer conClosed() + defer pingTicker.Stop() defer func() { c.debug().Msg("WebSocket Done") }() + defer c.conn.Close() for { select { - case reason := <-c.info.Close: - if reason == CloseDone { - return - } else { - _ = c.conn.CloseHandler()(websocket.CloseNormalClosure, reason) - conClosed() - } case message := <-c.info.Write: - if dead { - c.debug().Msg("WebSocket write on dead connection") - continue + if msg, ok := message.(outgoing.CloseWriter); ok { + c.debug().Str("reason", msg.Reason).Int("code", msg.Code).Msg("WebSocket Close") + c.CloseOnDone(msg.Code, msg.Reason) + return } _ = c.conn.SetWriteDeadline(time.Now().Add(writeWait)) @@ -156,7 +154,7 @@ func (c *Client) startWriteHandler(pingPeriod time.Duration) { c.debug().Interface("event", typed.Type).Interface("payload", typed.Payload).Msg("WebSocket Send") if err != nil { c.debug().Err(err).Msg("could not get typed message, exiting connection.") - conClosed() + c.CloseOnError(websocket.CloseNormalClosure, "malformed outgoing "+err.Error()) continue } @@ -165,14 +163,14 @@ func (c *Client) startWriteHandler(pingPeriod time.Duration) { } if err := writeJSON(c.conn, typed); err != nil { - conClosed() c.printWebSocketError("write", err) + c.CloseOnError(websocket.CloseNormalClosure, "write error"+err.Error()) } case <-pingTicker.C: _ = c.conn.SetWriteDeadline(time.Now().Add(writeWait)) if err := ping(c.conn); err != nil { - conClosed() c.printWebSocketError("ping", err) + c.CloseOnError(websocket.CloseNormalClosure, "ping timeout") } } } diff --git a/ws/event_connected.go b/ws/event_connected.go new file mode 100644 index 00000000..6758f926 --- /dev/null +++ b/ws/event_connected.go @@ -0,0 +1,8 @@ +package ws + +type Connected struct{} + +func (e Connected) Execute(rooms *Rooms, current ClientInfo) error { + rooms.connected[current.ID] = true + return nil +} diff --git a/ws/event_create.go b/ws/event_create.go index 2a280780..ac9882fe 100644 --- a/ws/event_create.go +++ b/ws/event_create.go @@ -71,7 +71,6 @@ func (e *Create) Execute(rooms *Rooms, current ClientInfo) error { Owner: true, Addr: current.Addr, Write: current.Write, - Close: current.Close, }, }, } diff --git a/ws/event_disconnected.go b/ws/event_disconnected.go index f4488671..6c945393 100644 --- a/ws/event_disconnected.go +++ b/ws/event_disconnected.go @@ -3,30 +3,41 @@ package ws import ( "bytes" + "github.com/gorilla/websocket" "github.com/screego/server/ws/outgoing" ) -type Disconnected struct{} +type Disconnected struct { + Code int + Reason string +} func (e *Disconnected) Execute(rooms *Rooms, current ClientInfo) error { + e.executeNoError(rooms, current) + return nil +} + +func (e *Disconnected) executeNoError(rooms *Rooms, current ClientInfo) { + delete(rooms.connected, current.ID) + current.Write <- outgoing.CloseWriter{Code: e.Code, Reason: e.Reason} + if current.RoomID == "" { - return nil + return } room, ok := rooms.Rooms[current.RoomID] if !ok { // room may already be removed - return nil + return } user, ok := room.Users[current.ID] if !ok { // room may already be removed - return nil + return } - current.Close <- CloseDone delete(room.Users, current.ID) usersLeftTotal.Inc() @@ -49,18 +60,19 @@ func (e *Disconnected) Execute(rooms *Rooms, current ClientInfo) error { if user.Owner && room.CloseOnOwnerLeave { for _, member := range room.Users { - member.Close <- CloseOwnerLeft + delete(rooms.connected, member.ID) + member.Write <- outgoing.CloseWriter{Code: websocket.CloseNormalClosure, Reason: CloseOwnerLeft} } rooms.closeRoom(current.RoomID) - return nil + return } if len(room.Users) == 0 { rooms.closeRoom(current.RoomID) - return nil + return } room.notifyInfoChanged() - return nil + return } diff --git a/ws/event_join.go b/ws/event_join.go index 20daeac4..5930f327 100644 --- a/ws/event_join.go +++ b/ws/event_join.go @@ -39,7 +39,6 @@ func (e *Join) Execute(rooms *Rooms, current ClientInfo) error { Owner: false, Addr: current.Addr, Write: current.Write, - Close: current.Close, } room.notifyInfoChanged() usersJoinedTotal.Inc() diff --git a/ws/outgoing/messages.go b/ws/outgoing/messages.go index dd2179ce..bb66a6bc 100644 --- a/ws/outgoing/messages.go +++ b/ws/outgoing/messages.go @@ -96,3 +96,12 @@ const ( ConnectionSTUN ConnectionMode = "stun" ConnectionTURN ConnectionMode = "turn" ) + +type CloseWriter struct { + Code int + Reason string +} + +func (CloseWriter) Type() string { + return "closewriter" +} diff --git a/ws/room.go b/ws/room.go index 0b41ea72..0c8bbafa 100644 --- a/ws/room.go +++ b/ws/room.go @@ -136,5 +136,4 @@ type User struct { Streaming bool Owner bool Write chan<- outgoing.Message - Close chan<- string } diff --git a/ws/rooms.go b/ws/rooms.go index 20aee1d7..e6a86cf1 100644 --- a/ws/rooms.go +++ b/ws/rooms.go @@ -8,6 +8,7 @@ import ( "time" "github.com/gorilla/websocket" + "github.com/rs/xid" "github.com/rs/zerolog/log" "github.com/screego/server/auth" "github.com/screego/server/config" @@ -19,6 +20,7 @@ func NewRooms(tServer turn.Server, users *auth.Users, conf config.Config) *Rooms return &Rooms{ Rooms: map[string]*Room{}, Incoming: make(chan ClientMessage), + connected: map[xid.ID]bool{}, turnServer: tServer, users: users, config: conf, @@ -49,6 +51,7 @@ type Rooms struct { users *auth.Users config config.Config r *rand.Rand + connected map[xid.ID]bool } func (r *Rooms) RandUserName() string { @@ -70,16 +73,22 @@ func (r *Rooms) Upgrade(w http.ResponseWriter, req *http.Request) { user, loggedIn := r.users.CurrentUser(req) c := newClient(conn, req, r.Incoming, user, loggedIn, r.config.TrustProxyHeaders) + r.Incoming <- ClientMessage{Info: c.info, Incoming: Connected{}, SkipConnectedCheck: true} go c.startReading(time.Second * 20) go c.startWriteHandler(time.Second * 5) } func (r *Rooms) Start() { - for { - msg := <-r.Incoming + for msg := range r.Incoming { + if !msg.SkipConnectedCheck && !r.connected[msg.Info.ID] { + log.Debug().Interface("event", fmt.Sprintf("%T", msg.Incoming)).Interface("payload", msg.Incoming).Msg("WebSocket Ignore") + continue + } + if err := msg.Incoming.Execute(r, msg.Info); err != nil { - msg.Info.Close <- err.Error() + dis := Disconnected{Code: websocket.CloseNormalClosure, Reason: err.Error()} + dis.executeNoError(r, msg.Info) } } }