diff --git a/syncx/batcher.go b/syncx/batcher.go index 57cb5c5..93eb8c5 100644 --- a/syncx/batcher.go +++ b/syncx/batcher.go @@ -12,10 +12,12 @@ type Batcher[T any] struct { process func(batch []T) maxItems int maxAge time.Duration - wg *sync.WaitGroup - buffer chan T - stop chan bool - batch []T + + wg *sync.WaitGroup + buffer chan T + stop chan bool + batch []T + timeout <-chan time.Time } // NewBatcher creates a new batcher. @@ -27,6 +29,8 @@ func NewBatcher[T any](process func(batch []T), maxItems int, maxAge time.Durati wg: wg, buffer: make(chan T, capacity), stop: make(chan bool), + batch: make([]T, 0, maxItems), + timeout: time.After(maxAge), } } @@ -45,21 +49,12 @@ func (b *Batcher[T]) Start() { b.flush() } - case <-time.After(b.maxAge): + case <-b.timeout: b.flush() case <-b.stop: - for len(b.buffer) > 0 || len(b.batch) > 0 { - buffSize := len(b.buffer) - canRead := min(b.maxItems-len(b.batch), buffSize) - - for i := 0; i < canRead; i++ { - v := <-b.buffer - b.batch = append(b.batch, v) - } - - b.flush() - } + b.drain() + close(b.buffer) return } } @@ -83,6 +78,21 @@ func (b *Batcher[T]) flush() { if len(b.batch) > 0 { b.process(b.batch) b.batch = make([]T, 0, b.maxItems) + b.timeout = time.After(b.maxAge) + } +} + +func (b *Batcher[T]) drain() { + for len(b.buffer) > 0 || len(b.batch) > 0 { + buffSize := len(b.buffer) + canRead := min(b.maxItems-len(b.batch), buffSize) + + for i := 0; i < canRead; i++ { + v := <-b.buffer + b.batch = append(b.batch, v) + } + + b.flush() } } diff --git a/syncx/batcher_test.go b/syncx/batcher_test.go index c47f5e2..8933a56 100644 --- a/syncx/batcher_test.go +++ b/syncx/batcher_test.go @@ -54,4 +54,9 @@ func TestBatcher(t *testing.T) { wg.Wait() assert.Equal(t, [][]int{{1, 2}, {3, 4}, {5}, {6, 7}, {8}}, batches) + + // panic if you try to queue to a stopped batcher + assert.Panics(t, func() { + b.Queue(9) + }) }