Skip to content

Commit

Permalink
fix: goroutine leak
Browse files Browse the repository at this point in the history
  • Loading branch information
jmattheis committed Sep 27, 2024
1 parent 4df4737 commit 471b5cc
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 54 deletions.
76 changes: 37 additions & 39 deletions ws/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ type Client struct {
}

type ClientMessage struct {
Info ClientInfo
Incoming Event
Info ClientInfo
SkipConnectedCheck bool
Incoming Event
}

type ClientInfo struct {
Expand All @@ -44,7 +45,6 @@ type ClientInfo struct {
Authenticated bool
AuthenticatedUser string
Write chan outgoing.Message
Close chan string
Addr net.IP
}

Expand All @@ -63,36 +63,46 @@ func newClient(conn *websocket.Conn, req *http.Request, read chan ClientMessage,
RoomID: "",
Addr: ip,
Write: make(chan outgoing.Message, 1),
Close: make(chan string, 1),
},
read: read,
}
client.debug().Msg("WebSocket New Connection")
conn.SetCloseHandler(func(code int, text string) error {
message := websocket.FormatCloseMessage(code, text)
client.debug().Str("reason", text).Int("code", code).Msg("WebSocket Close")
return conn.WriteControl(websocket.CloseMessage, message, time.Now().Add(writeWait))
})
return client
}

// Close closes the connection.
func (c *Client) Close() {
// CloseOnError closes the connection.
func (c *Client) CloseOnError(code int, reason string) {
c.once.Do(func() {
c.conn.Close()
go func() {
c.read <- ClientMessage{
Info: c.info,
Incoming: &Disconnected{},
Info: c.info,
Incoming: &Disconnected{
Code: code,
Reason: reason,
},
}
}()
c.writeCloseMessage(code, reason)
})
}

func (c *Client) CloseOnDone(code int, reason string) {
c.once.Do(func() {
c.writeCloseMessage(code, reason)
})
}

func (c *Client) writeCloseMessage(code int, reason string) {
message := websocket.FormatCloseMessage(code, reason)
c.conn.WriteControl(websocket.CloseMessage, message, time.Now().Add(writeWait))
c.conn.Close()
}

// startWriteHandler starts listening on the client connection. As we do not need anything from the client,
// we ignore incoming messages. Leaves the loop on errors.
func (c *Client) startReading(pongWait time.Duration) {
defer c.Close()
defer c.CloseOnError(websocket.CloseNormalClosure, "Reader Routine Closed")

_ = c.conn.SetReadDeadline(time.Now().Add(pongWait))
c.conn.SetPongHandler(func(appData string) error {
_ = c.conn.SetReadDeadline(time.Now().Add(pongWait))
Expand All @@ -101,17 +111,17 @@ func (c *Client) startReading(pongWait time.Duration) {
for {
t, m, err := c.conn.NextReader()
if err != nil {
c.printWebSocketError("read", err)
c.CloseOnError(websocket.CloseNormalClosure, "read error")
return
}
if t == websocket.BinaryMessage {
_ = c.conn.CloseHandler()(websocket.CloseUnsupportedData, fmt.Sprintf("unsupported binary message type: %s", err))
c.CloseOnError(websocket.CloseUnsupportedData, "unsupported binary message type")
return
}

incoming, err := ReadTypedIncoming(m)
if err != nil {
_ = c.conn.CloseHandler()(websocket.CloseNormalClosure, fmt.Sprintf("malformed message: %s", err))
c.CloseOnError(websocket.CloseUnsupportedData, fmt.Sprintf("malformed message: %s", err))
return
}
c.debug().Interface("event", fmt.Sprintf("%T", incoming)).Interface("payload", incoming).Msg("WebSocket Receive")
Expand All @@ -125,38 +135,26 @@ func (c *Client) startReading(pongWait time.Duration) {
// * on errors exit the loop.
func (c *Client) startWriteHandler(pingPeriod time.Duration) {
pingTicker := time.NewTicker(pingPeriod)

dead := false
conClosed := func() {
dead = true
c.Close()
pingTicker.Stop()
}
defer conClosed()
defer pingTicker.Stop()
defer func() {
c.debug().Msg("WebSocket Done")
}()
defer c.conn.Close()
for {
select {
case reason := <-c.info.Close:
if reason == CloseDone {
return
} else {
_ = c.conn.CloseHandler()(websocket.CloseNormalClosure, reason)
conClosed()
}
case message := <-c.info.Write:
if dead {
c.debug().Msg("WebSocket write on dead connection")
continue
if msg, ok := message.(outgoing.CloseWriter); ok {
c.debug().Str("reason", msg.Reason).Int("code", msg.Code).Msg("WebSocket Close")
c.CloseOnDone(msg.Code, msg.Reason)
return
}

_ = c.conn.SetWriteDeadline(time.Now().Add(writeWait))
typed, err := ToTypedOutgoing(message)
c.debug().Interface("event", typed.Type).Interface("payload", typed.Payload).Msg("WebSocket Send")
if err != nil {
c.debug().Err(err).Msg("could not get typed message, exiting connection.")
conClosed()
c.CloseOnError(websocket.CloseNormalClosure, "malformed outgoing "+err.Error())
continue
}

Expand All @@ -165,14 +163,14 @@ func (c *Client) startWriteHandler(pingPeriod time.Duration) {
}

if err := writeJSON(c.conn, typed); err != nil {
conClosed()
c.printWebSocketError("write", err)
c.CloseOnError(websocket.CloseNormalClosure, "write error"+err.Error())
}
case <-pingTicker.C:
_ = c.conn.SetWriteDeadline(time.Now().Add(writeWait))
if err := ping(c.conn); err != nil {
conClosed()
c.printWebSocketError("ping", err)
c.CloseOnError(websocket.CloseNormalClosure, "ping timeout")
}
}
}
Expand Down
8 changes: 8 additions & 0 deletions ws/event_connected.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package ws

type Connected struct{}

func (e Connected) Execute(rooms *Rooms, current ClientInfo) error {
rooms.connected[current.ID] = true
return nil
}
1 change: 0 additions & 1 deletion ws/event_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ func (e *Create) Execute(rooms *Rooms, current ClientInfo) error {
Owner: true,
Addr: current.Addr,
Write: current.Write,
Close: current.Close,
},
},
}
Expand Down
30 changes: 21 additions & 9 deletions ws/event_disconnected.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,41 @@ package ws
import (
"bytes"

"github.com/gorilla/websocket"
"github.com/screego/server/ws/outgoing"
)

type Disconnected struct{}
type Disconnected struct {
Code int
Reason string
}

func (e *Disconnected) Execute(rooms *Rooms, current ClientInfo) error {
e.executeNoError(rooms, current)
return nil
}

func (e *Disconnected) executeNoError(rooms *Rooms, current ClientInfo) {
delete(rooms.connected, current.ID)
current.Write <- outgoing.CloseWriter{Code: e.Code, Reason: e.Reason}

if current.RoomID == "" {
return nil
return
}

room, ok := rooms.Rooms[current.RoomID]
if !ok {
// room may already be removed
return nil
return
}

user, ok := room.Users[current.ID]

if !ok {
// room may already be removed
return nil
return
}

current.Close <- CloseDone
delete(room.Users, current.ID)
usersLeftTotal.Inc()

Expand All @@ -49,18 +60,19 @@ func (e *Disconnected) Execute(rooms *Rooms, current ClientInfo) error {

if user.Owner && room.CloseOnOwnerLeave {
for _, member := range room.Users {
member.Close <- CloseOwnerLeft
delete(rooms.connected, member.ID)
member.Write <- outgoing.CloseWriter{Code: websocket.CloseNormalClosure, Reason: CloseOwnerLeft}
}
rooms.closeRoom(current.RoomID)
return nil
return
}

if len(room.Users) == 0 {
rooms.closeRoom(current.RoomID)
return nil
return
}

room.notifyInfoChanged()

return nil
return
}
1 change: 0 additions & 1 deletion ws/event_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ func (e *Join) Execute(rooms *Rooms, current ClientInfo) error {
Owner: false,
Addr: current.Addr,
Write: current.Write,
Close: current.Close,
}
room.notifyInfoChanged()
usersJoinedTotal.Inc()
Expand Down
9 changes: 9 additions & 0 deletions ws/outgoing/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,12 @@ const (
ConnectionSTUN ConnectionMode = "stun"
ConnectionTURN ConnectionMode = "turn"
)

type CloseWriter struct {
Code int
Reason string
}

func (CloseWriter) Type() string {
return "closewriter"
}
1 change: 0 additions & 1 deletion ws/room.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,5 +136,4 @@ type User struct {
Streaming bool
Owner bool
Write chan<- outgoing.Message
Close chan<- string
}
15 changes: 12 additions & 3 deletions ws/rooms.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"time"

"github.com/gorilla/websocket"
"github.com/rs/xid"
"github.com/rs/zerolog/log"
"github.com/screego/server/auth"
"github.com/screego/server/config"
Expand All @@ -19,6 +20,7 @@ func NewRooms(tServer turn.Server, users *auth.Users, conf config.Config) *Rooms
return &Rooms{
Rooms: map[string]*Room{},
Incoming: make(chan ClientMessage),
connected: map[xid.ID]bool{},
turnServer: tServer,
users: users,
config: conf,
Expand Down Expand Up @@ -49,6 +51,7 @@ type Rooms struct {
users *auth.Users
config config.Config
r *rand.Rand
connected map[xid.ID]bool
}

func (r *Rooms) RandUserName() string {
Expand All @@ -70,16 +73,22 @@ func (r *Rooms) Upgrade(w http.ResponseWriter, req *http.Request) {

user, loggedIn := r.users.CurrentUser(req)
c := newClient(conn, req, r.Incoming, user, loggedIn, r.config.TrustProxyHeaders)
r.Incoming <- ClientMessage{Info: c.info, Incoming: Connected{}, SkipConnectedCheck: true}

go c.startReading(time.Second * 20)
go c.startWriteHandler(time.Second * 5)
}

func (r *Rooms) Start() {
for {
msg := <-r.Incoming
for msg := range r.Incoming {
if !msg.SkipConnectedCheck && !r.connected[msg.Info.ID] {
log.Debug().Interface("event", fmt.Sprintf("%T", msg.Incoming)).Interface("payload", msg.Incoming).Msg("WebSocket Ignore")
continue
}

if err := msg.Incoming.Execute(r, msg.Info); err != nil {
msg.Info.Close <- err.Error()
dis := Disconnected{Code: websocket.CloseNormalClosure, Reason: err.Error()}
dis.executeNoError(r, msg.Info)
}
}
}
Expand Down

0 comments on commit 471b5cc

Please sign in to comment.