Skip to content

Commit

Permalink
Merge pull request #1 from uber-go/master
Browse files Browse the repository at this point in the history
Merge branch 'master' from original repo
  • Loading branch information
abramlab authored Dec 8, 2022
2 parents c9b24f8 + b379e13 commit 5eb0e76
Show file tree
Hide file tree
Showing 9 changed files with 474 additions and 35 deletions.
16 changes: 12 additions & 4 deletions annotated_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1485,7 +1485,8 @@ func assertApp(
invoked *bool,
) {
t.Helper()
ctx := context.Background()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
assert.False(t, *started)
require.NoError(t, app.Start(ctx))
assert.True(t, *started)
Expand Down Expand Up @@ -1517,8 +1518,11 @@ func TestHookAnnotations(t *testing.T) {
t.Run("with hook on invoke", func(t *testing.T) {
t.Parallel()

var started bool
var invoked bool
var (
started bool
stopped bool
invoked bool
)
hook := fx.Annotate(
func() {
invoked = true
Expand All @@ -1527,10 +1531,14 @@ func TestHookAnnotations(t *testing.T) {
started = true
return nil
}),
fx.OnStop(func(context.Context) error {
stopped = true
return nil
}),
)
app := fxtest.New(t, fx.Invoke(hook))

assertApp(t, app, &started, nil, &invoked)
assertApp(t, app, &started, &stopped, &invoked)
})

t.Run("depend on result interface of target", func(t *testing.T) {
Expand Down
43 changes: 39 additions & 4 deletions app.go
Original file line number Diff line number Diff line change
Expand Up @@ -619,9 +619,13 @@ func (app *App) Start(ctx context.Context) (err error) {
})
}

func (app *App) start(ctx context.Context) error {
if err := app.lifecycle.Start(ctx); err != nil {
// Start failed, rolling back.
// withRollback will execute an anonymous function with a given context.
// if the anon func returns an error, rollback methods will be called and related events emitted
func (app *App) withRollback(
ctx context.Context,
f func(context.Context) error,
) error {
if err := f(ctx); err != nil {
app.log().LogEvent(&fxevent.RollingBack{StartErr: err})

stopErr := app.lifecycle.Stop(ctx)
Expand All @@ -633,9 +637,20 @@ func (app *App) start(ctx context.Context) error {

return err
}

return nil
}

func (app *App) start(ctx context.Context) error {
return app.withRollback(ctx, func(ctx context.Context) error {
if err := app.lifecycle.Start(ctx); err != nil {
return err
}
app.receivers.Start(ctx)
return nil
})
}

// Stop gracefully stops the application. It executes any registered OnStop
// hooks in reverse order, so that each constructor's stop hooks are called
// before its dependencies' stop hooks.
Expand All @@ -648,9 +663,14 @@ func (app *App) Stop(ctx context.Context) (err error) {
app.log().LogEvent(&fxevent.Stopped{Err: err})
}()

cb := func(ctx context.Context) error {
defer app.receivers.Stop(ctx)
return app.lifecycle.Stop(ctx)
}

return withTimeout(ctx, &withTimeoutParams{
hook: _onStopHook,
callback: app.lifecycle.Stop,
callback: cb,
lifecycle: app.lifecycle,
log: app.log(),
})
Expand All @@ -663,10 +683,25 @@ func (app *App) Stop(ctx context.Context) (err error) {
//
// Alternatively, a signal can be broadcast to all done channels manually by
// using the Shutdown functionality (see the Shutdowner documentation for details).
//
// Note: The channel Done returns will not receive a signal unless the application
// as been started via Start or Run.
func (app *App) Done() <-chan os.Signal {
return app.receivers.Done()
}

// Wait returns a channel of [ShutdownSignal] to block on after starting the
// application and function, similar to [App.Done], but with a minor difference.
// Should an ExitCode be provided as a [ShutdownOption] to
// the Shutdowner Shutdown method, the exit code will be available as part
// of the ShutdownSignal struct.
//
// Should the app receive a SIGTERM or SIGINT, the given
// signal will be populated in the ShutdownSignal struct.
func (app *App) Wait() <-chan ShutdownSignal {
return app.receivers.Wait()
}

// StartTimeout returns the configured startup timeout. Apps default to using
// DefaultTimeout, but users can configure this behavior using the
// StartTimeout option.
Expand Down
23 changes: 23 additions & 0 deletions app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1281,6 +1281,29 @@ func TestAppStart(t *testing.T) {
err := app.Start(context.Background()).Error()
assert.Contains(t, err, "OnStart hook added by go.uber.org/fx_test.TestAppStart.func10.1 failed: goroutine exited without returning")
})

t.Run("StartTwiceWithHooksErrors", func(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

app := fxtest.New(t,
Invoke(func(lc Lifecycle) {
lc.Append(Hook{
OnStart: func(ctx context.Context) error { return nil },
OnStop: func(ctx context.Context) error { return nil },
})
}),
)
assert.NoError(t, app.Start(ctx))
err := app.Start(ctx)
if assert.Error(t, err) {
assert.ErrorContains(t, err, "attempted to start lifecycle when in state: started")
}
app.Stop(ctx)
assert.NoError(t, app.Start(ctx))
})
}

func TestAppStop(t *testing.T) {
Expand Down
57 changes: 57 additions & 0 deletions internal/lifecycle/lifecycle.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,38 @@ type Hook struct {
callerFrame fxreflect.Frame
}

type appState int

const (
stopped appState = iota
starting
incompleteStart
started
stopping
)

func (as appState) String() string {
switch as {
case stopped:
return "stopped"
case starting:
return "starting"
case incompleteStart:
return "incompleteStart"
case started:
return "started"
case stopping:
return "stopping"
default:
return "invalidState"
}
}

// Lifecycle coordinates application lifecycle hooks.
type Lifecycle struct {
clock fxclock.Clock
logger fxevent.Logger
state appState
hooks []Hook
numStarted int
startRecords HookRecords
Expand Down Expand Up @@ -157,9 +185,23 @@ func (l *Lifecycle) Start(ctx context.Context) error {
}

l.mu.Lock()
if l.state != stopped {
defer l.mu.Unlock()
return fmt.Errorf("attempted to start lifecycle when in state: %v", l.state)
}
l.numStarted = 0
l.state = starting

l.startRecords = make(HookRecords, 0, len(l.hooks))
l.mu.Unlock()

var returnState appState = incompleteStart
defer func() {
l.mu.Lock()
l.state = returnState
l.mu.Unlock()
}()

for _, hook := range l.hooks {
// if ctx has cancelled, bail out of the loop.
if err := ctx.Err(); err != nil {
Expand Down Expand Up @@ -187,6 +229,7 @@ func (l *Lifecycle) Start(ctx context.Context) error {
l.numStarted++
}

returnState = started
return nil
}

Expand Down Expand Up @@ -221,6 +264,20 @@ func (l *Lifecycle) Stop(ctx context.Context) error {
return errors.New("called OnStop with nil context")
}

l.mu.Lock()
if l.state != started && l.state != incompleteStart {
defer l.mu.Unlock()
return fmt.Errorf("attempted to stop lifecycle when in state: %v", l.state)
}
l.state = stopping
l.mu.Unlock()

defer func() {
l.mu.Lock()
l.state = stopped
l.mu.Unlock()
}()

l.mu.Lock()
l.stopRecords = make(HookRecords, 0, l.numStarted)
l.mu.Unlock()
Expand Down
24 changes: 24 additions & 0 deletions internal/lifecycle/lifecycle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ func TestLifecycleStart(t *testing.T) {
assert.NoError(t, l.Start(context.Background()))
assert.Equal(t, 2, count)
})

t.Run("ErrHaltsChainAndRollsBack", func(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -143,6 +144,18 @@ func TestLifecycleStart(t *testing.T) {
// stop hooks.
require.NoError(t, l.Stop(ctx))
})

t.Run("StartWhileStartedErrors", func(t *testing.T) {
t.Parallel()

l := New(testLogger(t), fxclock.System)
assert.NoError(t, l.Start(context.Background()))
err := l.Start(context.Background())
require.Error(t, err)
assert.Contains(t, err.Error(), "attempted to start lifecycle when in state: started")
assert.NoError(t, l.Stop(context.Background()))
assert.NoError(t, l.Start(context.Background()))
})
}

func TestLifecycleStop(t *testing.T) {
Expand All @@ -152,6 +165,7 @@ func TestLifecycleStop(t *testing.T) {
t.Parallel()

l := New(testLogger(t), fxclock.System)
l.Start(context.Background())
assert.Nil(t, l.Stop(context.Background()), "no lifecycle hooks should have resulted in stop returning nil")
})

Expand Down Expand Up @@ -317,6 +331,16 @@ func TestLifecycleStop(t *testing.T) {
assert.Contains(t, err.Error(), "called OnStop with nil context")

})

t.Run("StopWhileStoppedErrors", func(t *testing.T) {
t.Parallel()

l := New(testLogger(t), fxclock.System)
err := l.Stop(context.Background())
require.Error(t, err)
assert.Contains(t, err.Error(), "attempted to stop lifecycle when in state: stopped")
})

}

func TestHookRecordsFormat(t *testing.T) {
Expand Down
63 changes: 61 additions & 2 deletions shutdown.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@

package fx

import (
"context"
"time"
)

// Shutdowner provides a method that can manually trigger the shutdown of the
// application by sending a signal to all open Done channels. Shutdowner works
// on applications using Run as well as Start, Done, and Stop. The Shutdowner is
Expand All @@ -34,8 +39,42 @@ type ShutdownOption interface {
apply(*shutdowner)
}

type exitCodeOption int

func (code exitCodeOption) apply(s *shutdowner) {
s.exitCode = int(code)
}

var _ ShutdownOption = exitCodeOption(0)

// ExitCode is a [ShutdownOption] that may be passed to the Shutdown method of the
// [Shutdowner] interface.
// The given integer exit code will be broadcasted to any receiver waiting
// on a [ShutdownSignal] from the [Wait] method.
func ExitCode(code int) ShutdownOption {
return exitCodeOption(code)
}

type shutdownTimeoutOption time.Duration

func (to shutdownTimeoutOption) apply(s *shutdowner) {
s.shutdownTimeout = time.Duration(to)
}

var _ ShutdownOption = shutdownTimeoutOption(0)

// ShutdownTimeout is a [ShutdownOption] that allows users to specify a timeout
// for a given call to Shutdown method of the [Shutdowner] interface. As the
// Shutdown method will block while waiting for a signal receiver relay
// goroutine to stop.
func ShutdownTimeout(timeout time.Duration) ShutdownOption {
return shutdownTimeoutOption(timeout)
}

type shutdowner struct {
app *App
app *App
exitCode int
shutdownTimeout time.Duration
}

// Shutdown broadcasts a signal to all of the application's Done channels
Expand All @@ -44,7 +83,27 @@ type shutdowner struct {
// In practice this means Shutdowner.Shutdown should not be called from an
// fx.Invoke, but from a fx.Lifecycle.OnStart hook.
func (s *shutdowner) Shutdown(opts ...ShutdownOption) error {
return s.app.receivers.Broadcast(ShutdownSignal{Signal: _sigTERM})
for _, opt := range opts {
opt.apply(s)
}

ctx := context.Background()

if s.shutdownTimeout != time.Duration(0) {
c, cancel := context.WithTimeout(
context.Background(),
s.shutdownTimeout,
)
defer cancel()
ctx = c
}

defer s.app.receivers.Stop(ctx)

return s.app.receivers.Broadcast(ShutdownSignal{
Signal: _sigTERM,
ExitCode: s.exitCode,
})
}

func (app *App) shutdowner() Shutdowner {
Expand Down
Loading

0 comments on commit 5eb0e76

Please sign in to comment.