Skip to content

Commit

Permalink
Merge pull request #2 from k1LoW/timeout
Browse files Browse the repository at this point in the history
Add WaitWithTimeout
  • Loading branch information
k1LoW authored Feb 5, 2024
2 parents 01137a1 + 4fe69b7 commit 3441717
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 15 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import (
func main() {
ctx, cancel := donegroup.WithCancel(context.Background())

if err := donegroup.Clenup(ctx, func() error {
if err := donegroup.Clenup(ctx, func(_ context.Context) error {
// Cleanup process of some kind
fmt.Println("cleanup")
return nil
Expand Down
44 changes: 37 additions & 7 deletions donegroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ import (
"context"
"errors"
"sync"
"time"

"golang.org/x/sync/errgroup"
)

var doneGroupKey = struct{}{}

type doneGroup struct {
ctx context.Context
cleanupGroups []*errgroup.Group
mu sync.Mutex
}
Expand All @@ -34,21 +36,21 @@ func WithCancelWithKey(ctx context.Context, key any) (context.Context, context.C
}

// Clenup runs f when the context is canceled.
func Clenup(ctx context.Context, f func() error) error {
func Clenup(ctx context.Context, f func(ctx context.Context) error) error {
return ClenupWithKey(ctx, doneGroupKey, f)
}

// ClenupWithKey runs f when the context is canceled.
func ClenupWithKey(ctx context.Context, key any, f func() error) error {
func ClenupWithKey(ctx context.Context, key any, f func(ctx context.Context) error) error {
dg, ok := ctx.Value(key).(*doneGroup)
if !ok {
return errors.New("donegroup: context does not contain a donegroup. Use donegroup.WithCancel to create a context with a donegroup")
return errors.New("donegroup: context does not contain a doneGroup. Use donegroup.WithCancel to create a context with a doneGroup")
}

first := dg.cleanupGroups[0]
first.Go(func() error {
<-ctx.Done()
return f()
return dg.goWithCtx(f)
})
return nil
}
Expand All @@ -58,16 +60,44 @@ func Wait(ctx context.Context) error {
return WaitWithKey(ctx, doneGroupKey)
}

// Wait blocks until the context is canceled or the timeout is reached.
func WaitWithTimeout(ctx context.Context, timeout time.Duration) error {
return WaitWithKeyAndTimeout(ctx, doneGroupKey, timeout)
}

// WaitWithKey blocks until the context is canceled.
func WaitWithKey(ctx context.Context, key any) error {
<-ctx.Done()
return WaitWithKeyAndTimeout(ctx, key, 0)
}

// WaitWithKeyAndTimeout blocks until the context is canceled or the timeout is reached.
func WaitWithKeyAndTimeout(ctx context.Context, key any, timeout time.Duration) error {
dg, ok := ctx.Value(key).(*doneGroup)
if !ok {
return errors.New("donegroup: context does not contain a donegroup. Use donegroup.WithCancel to create a context with a donegroup")
return errors.New("donegroup: context does not contain a doneGroup. Use donegroup.WithCancel to create a context with a doneGroup")
}
eg := new(errgroup.Group)
ctxx := context.Background()
var cancel context.CancelFunc
if timeout != 0 {
ctxx, cancel = context.WithTimeout(ctxx, timeout)
defer cancel()
}
dg.mu.Lock()
dg.ctx = ctxx
dg.mu.Unlock()

<-ctx.Done()
eg, _ := errgroup.WithContext(ctxx)
for _, g := range dg.cleanupGroups {
eg.Go(g.Wait)
}

return eg.Wait()
}

func (dg *doneGroup) goWithCtx(f func(ctx context.Context) error) error {
dg.mu.Lock()
ctx := dg.ctx
dg.mu.Unlock()
return f(ctx)
}
42 changes: 36 additions & 6 deletions donegroup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package donegroup

import (
"context"
"errors"
"sync"
"testing"
"time"
Expand All @@ -13,7 +14,7 @@ func TestDoneGroup(t *testing.T) {

cleanup := false

if err := Clenup(ctx, func() error {
if err := Clenup(ctx, func(_ context.Context) error {
time.Sleep(10 * time.Millisecond)
cleanup = true
return nil
Expand Down Expand Up @@ -44,7 +45,7 @@ func TestMultiCleanup(t *testing.T) {
cleanup := 0

for i := 0; i < 10; i++ {
if err := Clenup(ctx, func() error {
if err := Clenup(ctx, func(_ context.Context) error {
time.Sleep(10 * time.Millisecond)
mu.Lock()
defer mu.Unlock()
Expand Down Expand Up @@ -78,7 +79,7 @@ func TestNested(t *testing.T) {
secondCleanup := 0

for i := 0; i < 10; i++ {
if err := Clenup(firstCtx, func() error {
if err := Clenup(firstCtx, func(_ context.Context) error {
time.Sleep(10 * time.Millisecond)
mu.Lock()
defer mu.Unlock()
Expand All @@ -90,7 +91,7 @@ func TestNested(t *testing.T) {
}

for i := 0; i < 5; i++ {
if err := Clenup(secondCtx, func() error {
if err := Clenup(secondCtx, func(_ context.Context) error {
time.Sleep(10 * time.Millisecond)
mu.Lock()
defer mu.Unlock()
Expand Down Expand Up @@ -137,7 +138,7 @@ func TestRootWaitAll(t *testing.T) {
leafCleanup := 0

for i := 0; i < 10; i++ {
if err := Clenup(rootCtx, func() error {
if err := Clenup(rootCtx, func(_ context.Context) error {
time.Sleep(10 * time.Millisecond)
mu.Lock()
defer mu.Unlock()
Expand All @@ -149,7 +150,7 @@ func TestRootWaitAll(t *testing.T) {
}

for i := 0; i < 5; i++ {
if err := Clenup(leafCtx, func() error {
if err := Clenup(leafCtx, func(_ context.Context) error {
time.Sleep(10 * time.Millisecond)
mu.Lock()
defer mu.Unlock()
Expand Down Expand Up @@ -180,3 +181,32 @@ func TestRootWaitAll(t *testing.T) {
}
}()
}

func TestWaitWithTimeout(t *testing.T) {
t.Parallel()
ctx, cancel := WithCancel(context.Background())

if err := Clenup(ctx, func(ctx context.Context) error {
for i := 0; i < 10; i++ {
select {
case <-ctx.Done():
return ctx.Err()
default:
time.Sleep(2 * time.Millisecond)
}
}
return nil
}); err != nil {
t.Error(err)
}

timeout := 5 * time.Millisecond

defer func() {
cancel()

if err := WaitWithTimeout(ctx, timeout); !errors.Is(err, context.DeadlineExceeded) {
t.Error("expected timeout error")
}
}()
}
2 changes: 1 addition & 1 deletion example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
func Example() {
ctx, cancel := donegroup.WithCancel(context.Background())

if err := donegroup.Clenup(ctx, func() error {
if err := donegroup.Clenup(ctx, func(_ context.Context) error {
// Cleanup process of some kind
time.Sleep(10 * time.Millisecond)
fmt.Println("cleanup")
Expand Down

0 comments on commit 3441717

Please sign in to comment.