Skip to content

Commit

Permalink
pool: Correct endpoint waitgroup logic.
Browse files Browse the repository at this point in the history
This corrects the waitgroup logic in endpoint handling to ensure it
properly terminates when the context is canceled.

It also enables the wait at the end of the endpoint test to help prove
correctness.
  • Loading branch information
davecgh committed Sep 14, 2023
1 parent f9ffa9e commit 91ecde3
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 27 deletions.
40 changes: 17 additions & 23 deletions pool/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ type connection struct {
type Endpoint struct {
listenAddr string
connCh chan *connection
discCh chan struct{}
listener net.Listener
cfg *EndpointConfig
clients map[string]*Client
Expand All @@ -88,7 +87,6 @@ func NewEndpoint(eCfg *EndpointConfig, listenAddr string) (*Endpoint, error) {
cfg: eCfg,
clients: make(map[string]*Client),
connCh: make(chan *connection, bufferSize),
discCh: make(chan struct{}, bufferSize),
}
listener, err := net.Listen("tcp", listenAddr)
if err != nil {
Expand All @@ -109,7 +107,9 @@ func (e *Endpoint) removeClient(c *Client) {

// listen accepts incoming client connections on the endpoint.
// It must be run as a goroutine.
func (e *Endpoint) listen() {
func (e *Endpoint) listen(ctx context.Context) {
defer e.wg.Done()

log.Infof("listening on %s", e.listenAddr)
for {
conn, err := e.listener.Accept()
Expand All @@ -126,16 +126,19 @@ func (e *Endpoint) listen() {
log.Errorf("unable to accept client connection: %v", err)
return
}
e.connCh <- &connection{
Conn: conn,
Done: make(chan bool),
select {
case <-ctx.Done():
return
case e.connCh <- &connection{Conn: conn, Done: make(chan bool)}:
}
}
}

// connect creates new pool clients from established connections.
// It must be run as a goroutine.
func (e *Endpoint) connect(ctx context.Context) {
defer e.wg.Done()

for {
select {
case <-ctx.Done():
Expand Down Expand Up @@ -203,22 +206,13 @@ func (e *Endpoint) connect(ctx context.Context) {
// disconnect relays client disconnections to the endpoint for processing.
// It must be run as a goroutine.
func (e *Endpoint) disconnect(ctx context.Context) {
for {
select {
case <-ctx.Done():
e.clientsMtx.Lock()
for _, client := range e.clients {
client.cancel()
}
e.clientsMtx.Unlock()

e.wg.Done()
return

case <-e.discCh:
e.wg.Done()
}
<-ctx.Done()
e.clientsMtx.Lock()
for _, client := range e.clients {
client.cancel()
}
e.clientsMtx.Unlock()
e.wg.Done()
}

// generateHashIDs generates hash ids of all client connections to the pool.
Expand All @@ -238,8 +232,8 @@ func (e *Endpoint) generateHashIDs() map[string]struct{} {
// run handles the lifecycle of all endpoint related processes.
// This should be run as a goroutine.
func (e *Endpoint) run(ctx context.Context) {
e.wg.Add(1)
go e.listen()
e.wg.Add(3)
go e.listen(ctx)
go e.connect(ctx)
go e.disconnect(ctx)
e.wg.Wait()
Expand Down
5 changes: 1 addition & 4 deletions pool/endpoint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ func testEndpoint(t *testing.T) {
t.Fatalf("[NewEndpoint] unexpected error: %v", err)
}
ctx, cancel := context.WithCancel(context.Background())
endpoint.wg.Add(1)
var wg sync.WaitGroup
wg.Add(1)
go func() {
Expand Down Expand Up @@ -235,7 +234,5 @@ func testEndpoint(t *testing.T) {
defer conn.Close()

cancel()
// TODO: This never finishes because endpoint.run never actually finishes
// due to the internal waitgroup not being handled properly.
// wg.Wait()
wg.Wait()
}

0 comments on commit 91ecde3

Please sign in to comment.