From 91ecde3a44853ac9ddfd005174717255c5969fe4 Mon Sep 17 00:00:00 2001 From: Dave Collins Date: Wed, 13 Sep 2023 15:33:40 -0500 Subject: [PATCH] pool: Correct endpoint waitgroup logic. This corrects the waitgroup logic in endpoint handling to ensure it properly terminates when the context is canceled. It also enables the wait at the end of the endpoint test to help prove correctness. --- pool/endpoint.go | 40 +++++++++++++++++----------------------- pool/endpoint_test.go | 5 +---- 2 files changed, 18 insertions(+), 27 deletions(-) diff --git a/pool/endpoint.go b/pool/endpoint.go index c596cfd7..7647a2d9 100644 --- a/pool/endpoint.go +++ b/pool/endpoint.go @@ -73,7 +73,6 @@ type connection struct { type Endpoint struct { listenAddr string connCh chan *connection - discCh chan struct{} listener net.Listener cfg *EndpointConfig clients map[string]*Client @@ -88,7 +87,6 @@ func NewEndpoint(eCfg *EndpointConfig, listenAddr string) (*Endpoint, error) { cfg: eCfg, clients: make(map[string]*Client), connCh: make(chan *connection, bufferSize), - discCh: make(chan struct{}, bufferSize), } listener, err := net.Listen("tcp", listenAddr) if err != nil { @@ -109,7 +107,9 @@ func (e *Endpoint) removeClient(c *Client) { // listen accepts incoming client connections on the endpoint. // It must be run as a goroutine. -func (e *Endpoint) listen() { +func (e *Endpoint) listen(ctx context.Context) { + defer e.wg.Done() + log.Infof("listening on %s", e.listenAddr) for { conn, err := e.listener.Accept() @@ -126,9 +126,10 @@ func (e *Endpoint) listen() { log.Errorf("unable to accept client connection: %v", err) return } - e.connCh <- &connection{ - Conn: conn, - Done: make(chan bool), + select { + case <-ctx.Done(): + return + case e.connCh <- &connection{Conn: conn, Done: make(chan bool)}: } } } @@ -136,6 +137,8 @@ func (e *Endpoint) listen() { // connect creates new pool clients from established connections. // It must be run as a goroutine. func (e *Endpoint) connect(ctx context.Context) { + defer e.wg.Done() + for { select { case <-ctx.Done(): @@ -203,22 +206,13 @@ func (e *Endpoint) connect(ctx context.Context) { // disconnect relays client disconnections to the endpoint for processing. // It must be run as a goroutine. func (e *Endpoint) disconnect(ctx context.Context) { - for { - select { - case <-ctx.Done(): - e.clientsMtx.Lock() - for _, client := range e.clients { - client.cancel() - } - e.clientsMtx.Unlock() - - e.wg.Done() - return - - case <-e.discCh: - e.wg.Done() - } + <-ctx.Done() + e.clientsMtx.Lock() + for _, client := range e.clients { + client.cancel() } + e.clientsMtx.Unlock() + e.wg.Done() } // generateHashIDs generates hash ids of all client connections to the pool. @@ -238,8 +232,8 @@ func (e *Endpoint) generateHashIDs() map[string]struct{} { // run handles the lifecycle of all endpoint related processes. // This should be run as a goroutine. func (e *Endpoint) run(ctx context.Context) { - e.wg.Add(1) - go e.listen() + e.wg.Add(3) + go e.listen(ctx) go e.connect(ctx) go e.disconnect(ctx) e.wg.Wait() diff --git a/pool/endpoint_test.go b/pool/endpoint_test.go index 353b4697..3510343e 100644 --- a/pool/endpoint_test.go +++ b/pool/endpoint_test.go @@ -83,7 +83,6 @@ func testEndpoint(t *testing.T) { t.Fatalf("[NewEndpoint] unexpected error: %v", err) } ctx, cancel := context.WithCancel(context.Background()) - endpoint.wg.Add(1) var wg sync.WaitGroup wg.Add(1) go func() { @@ -235,7 +234,5 @@ func testEndpoint(t *testing.T) { defer conn.Close() cancel() - // TODO: This never finishes because endpoint.run never actually finishes - // due to the internal waitgroup not being handled properly. - // wg.Wait() + wg.Wait() }