diff --git a/README.md b/README.md index acf0b37..8410bc5 100644 --- a/README.md +++ b/README.md @@ -23,7 +23,7 @@ import ( func main() { ctx, cancel := donegroup.WithCancel(context.Background()) - if err := donegroup.Clenup(ctx, func() error { + if err := donegroup.Clenup(ctx, func(_ context.Context) error { // Cleanup process of some kind fmt.Println("cleanup") return nil diff --git a/donegroup.go b/donegroup.go index 2744689..c7bc319 100644 --- a/donegroup.go +++ b/donegroup.go @@ -4,6 +4,7 @@ import ( "context" "errors" "sync" + "time" "golang.org/x/sync/errgroup" ) @@ -11,6 +12,7 @@ import ( var doneGroupKey = struct{}{} type doneGroup struct { + ctx context.Context cleanupGroups []*errgroup.Group mu sync.Mutex } @@ -34,21 +36,21 @@ func WithCancelWithKey(ctx context.Context, key any) (context.Context, context.C } // Clenup runs f when the context is canceled. -func Clenup(ctx context.Context, f func() error) error { +func Clenup(ctx context.Context, f func(ctx context.Context) error) error { return ClenupWithKey(ctx, doneGroupKey, f) } // ClenupWithKey runs f when the context is canceled. -func ClenupWithKey(ctx context.Context, key any, f func() error) error { +func ClenupWithKey(ctx context.Context, key any, f func(ctx context.Context) error) error { dg, ok := ctx.Value(key).(*doneGroup) if !ok { - return errors.New("donegroup: context does not contain a donegroup. Use donegroup.WithCancel to create a context with a donegroup") + return errors.New("donegroup: context does not contain a doneGroup. Use donegroup.WithCancel to create a context with a doneGroup") } first := dg.cleanupGroups[0] first.Go(func() error { <-ctx.Done() - return f() + return dg.goWithCtx(f) }) return nil } @@ -58,16 +60,44 @@ func Wait(ctx context.Context) error { return WaitWithKey(ctx, doneGroupKey) } +// Wait blocks until the context is canceled or the timeout is reached. +func WaitWithTimeout(ctx context.Context, timeout time.Duration) error { + return WaitWithKeyAndTimeout(ctx, doneGroupKey, timeout) +} + // WaitWithKey blocks until the context is canceled. func WaitWithKey(ctx context.Context, key any) error { - <-ctx.Done() + return WaitWithKeyAndTimeout(ctx, key, 0) +} + +// WaitWithKeyAndTimeout blocks until the context is canceled or the timeout is reached. +func WaitWithKeyAndTimeout(ctx context.Context, key any, timeout time.Duration) error { dg, ok := ctx.Value(key).(*doneGroup) if !ok { - return errors.New("donegroup: context does not contain a donegroup. Use donegroup.WithCancel to create a context with a donegroup") + return errors.New("donegroup: context does not contain a doneGroup. Use donegroup.WithCancel to create a context with a doneGroup") } - eg := new(errgroup.Group) + ctxx := context.Background() + var cancel context.CancelFunc + if timeout != 0 { + ctxx, cancel = context.WithTimeout(ctxx, timeout) + defer cancel() + } + dg.mu.Lock() + dg.ctx = ctxx + dg.mu.Unlock() + + <-ctx.Done() + eg, _ := errgroup.WithContext(ctxx) for _, g := range dg.cleanupGroups { eg.Go(g.Wait) } + return eg.Wait() } + +func (dg *doneGroup) goWithCtx(f func(ctx context.Context) error) error { + dg.mu.Lock() + ctx := dg.ctx + dg.mu.Unlock() + return f(ctx) +} diff --git a/donegroup_test.go b/donegroup_test.go index 807b467..bd15276 100644 --- a/donegroup_test.go +++ b/donegroup_test.go @@ -2,6 +2,7 @@ package donegroup import ( "context" + "errors" "sync" "testing" "time" @@ -13,7 +14,7 @@ func TestDoneGroup(t *testing.T) { cleanup := false - if err := Clenup(ctx, func() error { + if err := Clenup(ctx, func(_ context.Context) error { time.Sleep(10 * time.Millisecond) cleanup = true return nil @@ -44,7 +45,7 @@ func TestMultiCleanup(t *testing.T) { cleanup := 0 for i := 0; i < 10; i++ { - if err := Clenup(ctx, func() error { + if err := Clenup(ctx, func(_ context.Context) error { time.Sleep(10 * time.Millisecond) mu.Lock() defer mu.Unlock() @@ -78,7 +79,7 @@ func TestNested(t *testing.T) { secondCleanup := 0 for i := 0; i < 10; i++ { - if err := Clenup(firstCtx, func() error { + if err := Clenup(firstCtx, func(_ context.Context) error { time.Sleep(10 * time.Millisecond) mu.Lock() defer mu.Unlock() @@ -90,7 +91,7 @@ func TestNested(t *testing.T) { } for i := 0; i < 5; i++ { - if err := Clenup(secondCtx, func() error { + if err := Clenup(secondCtx, func(_ context.Context) error { time.Sleep(10 * time.Millisecond) mu.Lock() defer mu.Unlock() @@ -137,7 +138,7 @@ func TestRootWaitAll(t *testing.T) { leafCleanup := 0 for i := 0; i < 10; i++ { - if err := Clenup(rootCtx, func() error { + if err := Clenup(rootCtx, func(_ context.Context) error { time.Sleep(10 * time.Millisecond) mu.Lock() defer mu.Unlock() @@ -149,7 +150,7 @@ func TestRootWaitAll(t *testing.T) { } for i := 0; i < 5; i++ { - if err := Clenup(leafCtx, func() error { + if err := Clenup(leafCtx, func(_ context.Context) error { time.Sleep(10 * time.Millisecond) mu.Lock() defer mu.Unlock() @@ -180,3 +181,32 @@ func TestRootWaitAll(t *testing.T) { } }() } + +func TestWaitWithTimeout(t *testing.T) { + t.Parallel() + ctx, cancel := WithCancel(context.Background()) + + if err := Clenup(ctx, func(ctx context.Context) error { + for i := 0; i < 10; i++ { + select { + case <-ctx.Done(): + return ctx.Err() + default: + time.Sleep(2 * time.Millisecond) + } + } + return nil + }); err != nil { + t.Error(err) + } + + timeout := 5 * time.Millisecond + + defer func() { + cancel() + + if err := WaitWithTimeout(ctx, timeout); !errors.Is(err, context.DeadlineExceeded) { + t.Error("expected timeout error") + } + }() +} diff --git a/example_test.go b/example_test.go index 8aad64c..548f2fc 100644 --- a/example_test.go +++ b/example_test.go @@ -12,7 +12,7 @@ import ( func Example() { ctx, cancel := donegroup.WithCancel(context.Background()) - if err := donegroup.Clenup(ctx, func() error { + if err := donegroup.Clenup(ctx, func(_ context.Context) error { // Cleanup process of some kind time.Sleep(10 * time.Millisecond) fmt.Println("cleanup")