diff --git a/pool/endpoint_test.go b/pool/endpoint_test.go index feb9b929..76f8783b 100644 --- a/pool/endpoint_test.go +++ b/pool/endpoint_test.go @@ -28,6 +28,7 @@ func makeConn(listener *net.TCPListener, serverCh chan net.Conn) (net.Conn, net. } func testEndpoint(t *testing.T) { + const maxConnsPerHost = 3 powLimit := chaincfg.SimNetParams().PowLimit iterations := math.Pow(2, float64(256-powLimit.BitLen())) maxGenTime := time.Second * 20 @@ -36,14 +37,14 @@ func testEndpoint(t *testing.T) { new(big.Rat).SetInt(powLimit), maxGenTime) connections := make(map[string]uint32) var connectionsMtx sync.RWMutex - var connectionsWg sync.WaitGroup + removeConn := make(chan struct{}, maxConnsPerHost) eCfg := &EndpointConfig{ ActiveNet: chaincfg.SimNetParams(), db: db, SoloPool: true, Blake256Pad: blake256Pad, NonceIterations: iterations, - MaxConnectionsPerHost: 3, + MaxConnectionsPerHost: maxConnsPerHost, FetchMinerDifficulty: func(miner string) (*DifficultyInfo, error) { return poolDiffs.fetchMinerDifficulty(miner) }, @@ -57,7 +58,6 @@ func testEndpoint(t *testing.T) { return true }, AddConnection: func(host string) { - connectionsWg.Add(1) connectionsMtx.Lock() connections[host]++ connectionsMtx.Unlock() @@ -66,7 +66,7 @@ func testEndpoint(t *testing.T) { connectionsMtx.Lock() connections[host]-- connectionsMtx.Unlock() - connectionsWg.Done() + removeConn <- struct{}{} }, FetchHostConnections: func(host string) uint32 { connectionsMtx.RLock() @@ -91,7 +91,18 @@ func testEndpoint(t *testing.T) { endpoint.run(ctx) wg.Done() }() - time.Sleep(time.Millisecond * 100) + sendToConnChanOrFatal := func(msg *connection) { + select { + case endpoint.connCh <- msg: + case <-ctx.Done(): + t.Fatalf("unexpected endpoint shutdown") + } + select { + case <-msg.Done: + case <-ctx.Done(): + t.Fatalf("unexpected endpoint shutdown") + } + } laddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:3031") if err != nil { @@ -134,8 +145,7 @@ func testEndpoint(t *testing.T) { Conn: connA, Done: make(chan struct{}), } - endpoint.connCh <- msgA - <-msgA.Done + sendToConnChanOrFatal(msgA) addr := connA.RemoteAddr() tcpAddr, err := net.ResolveTCPAddr(addr.Network(), addr.String()) if err != nil { @@ -161,8 +171,7 @@ func testEndpoint(t *testing.T) { Conn: connB, Done: make(chan struct{}), } - endpoint.connCh <- msgB - <-msgB.Done + sendToConnChanOrFatal(msgB) connC, srvC, err := makeConn(ln, serverCh) if err != nil { t.Fatalf("unexpected error: %v", err) @@ -173,8 +182,7 @@ func testEndpoint(t *testing.T) { Conn: connC, Done: make(chan struct{}), } - endpoint.connCh <- msgC - <-msgC.Done + sendToConnChanOrFatal(msgC) // Ensure the connected clients to the host got incremented to 3. hostConnections = endpoint.cfg.FetchHostConnections(host) @@ -194,8 +202,7 @@ func testEndpoint(t *testing.T) { Conn: connD, Done: make(chan struct{}), } - endpoint.connCh <- msgD - <-msgD.Done + sendToConnChanOrFatal(msgD) // Ensure the connected clients count to the host stayed at 3 because // the recent connection got rejected due to MaxConnectionCountPerHost @@ -206,7 +213,7 @@ func testEndpoint(t *testing.T) { "for host %s, got %d", 3, host, hostConnections) } - // Remove all clients. + // Remove all clients and wait for their removal. endpoint.clientsMtx.Lock() clients := make([]*Client, 0, len(endpoint.clients)) for _, cl := range endpoint.clients { @@ -216,7 +223,13 @@ func testEndpoint(t *testing.T) { for _, cl := range clients { cl.shutdown() } - connectionsWg.Wait() + for i := 0; i < len(clients); i++ { + select { + case <-removeConn: + case <-time.After(time.Second): + t.Fatalf("timeout waiting for connection removal") + } + } // Ensure there are no connected clients to the host. hostConnections = endpoint.cfg.FetchHostConnections(host)