diff --git a/internal/handler/exec.go b/internal/handler/exec.go index eb1a209..866e803 100644 --- a/internal/handler/exec.go +++ b/internal/handler/exec.go @@ -27,7 +27,8 @@ func ExecHandler(cmd string) error { return err } - return command.ExecCommand( + exeCmd, err := command.ExecCommand( + context.Background(), ecsService, cluster, *task.TaskArn, @@ -35,4 +36,10 @@ func ExecHandler(cmd string) error { container.Name, cfg.Region, ) + + if err != nil { + return err + } + + return exeCmd.Run() } diff --git a/internal/handler/portforward.go b/internal/handler/portforward.go index 7c84b84..6224ffa 100644 --- a/internal/handler/portforward.go +++ b/internal/handler/portforward.go @@ -32,7 +32,8 @@ func PortforwardHandler(doc command.DocumentName, params map[string][]string) er taskId := strings.Split(*task.TaskArn, "/")[2] - return command.PortForwardCommand( + cmd, err := command.PortForwardCommand( + context.Background(), ssmService, cluster, taskId, @@ -41,4 +42,10 @@ func PortforwardHandler(doc command.DocumentName, params map[string][]string) er doc, params, ) + + if err != nil { + return err + } + + return cmd.Run() } diff --git a/internal/session_manager/session.go b/internal/session_manager/session.go index d5349c6..aa91865 100644 --- a/internal/session_manager/session.go +++ b/internal/session_manager/session.go @@ -1,14 +1,18 @@ package session_manager -import "os/exec" +import ( + "context" + "os/exec" +) const SESSION_MANAGER_COMMAND = "session-manager-plugin" -func MakeStartSessionCmd(response string, region string) *exec.Cmd { +func MakeStartSessionCmd(ctx context.Context, response string, region string) *exec.Cmd { const OperationName = "StartSession" // https://github.com/aws/session-manager-plugin/blob/1.2.463.0/src/sessionmanagerplugin/session/session.go#L163-L178 - return exec.Command( + return exec.CommandContext( + ctx, SESSION_MANAGER_COMMAND, response, region, diff --git a/pkg/command/exec.go b/pkg/command/exec.go index 0efe3c9..b36c3bf 100644 --- a/pkg/command/exec.go +++ b/pkg/command/exec.go @@ -4,13 +4,14 @@ import ( "context" "encoding/json" "os" + "os/exec" "github.com/aws/aws-sdk-go-v2/service/ecs" "github.com/aws/aws-sdk-go/aws" "github.com/wim-web/tonneeeeel/internal/session_manager" ) -func ExecCommand(c *ecs.Client, cluster string, task string, command string, container *string, region string) error { +func ExecCommand(ctx context.Context, c *ecs.Client, cluster string, task string, command string, container *string, region string) (*exec.Cmd, error) { input := &ecs.ExecuteCommandInput{ Cluster: aws.String(cluster), Task: aws.String(task), @@ -22,20 +23,20 @@ func ExecCommand(c *ecs.Client, cluster string, task string, command string, con res, err := c.ExecuteCommand(context.Background(), input) if err != nil { - return err + return nil, err } r, err := json.Marshal(res.Session) if err != nil { - return err + return nil, err } - cmd := session_manager.MakeStartSessionCmd(string(r), region) + cmd := session_manager.MakeStartSessionCmd(ctx, string(r), region) cmd.Stdout = os.Stdout cmd.Stdin = os.Stdin cmd.Stderr = os.Stderr - return cmd.Run() + return cmd, nil } diff --git a/pkg/command/portforward.go b/pkg/command/portforward.go index 5a9c63d..ad7619b 100644 --- a/pkg/command/portforward.go +++ b/pkg/command/portforward.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "os" + "os/exec" "github.com/aws/aws-sdk-go-v2/service/ssm" "github.com/aws/aws-sdk-go/aws" @@ -18,7 +19,7 @@ const ( REMOTE_PORT_FORWARD_DOCUMENT_NAME DocumentName = "AWS-StartPortForwardingSessionToRemoteHost" ) -func PortForwardCommand(c *ssm.Client, cluster string, taskId string, containerId string, region string, doc DocumentName, params map[string][]string) error { +func PortForwardCommand(ctx context.Context, c *ssm.Client, cluster string, taskId string, containerId string, region string, doc DocumentName, params map[string][]string) (*exec.Cmd, error) { input := &ssm.StartSessionInput{ Target: aws.String(fmt.Sprintf("ecs:%s_%s_%s", cluster, taskId, containerId)), DocumentName: aws.String(string(doc)), @@ -28,20 +29,20 @@ func PortForwardCommand(c *ssm.Client, cluster string, taskId string, containerI res, err := c.StartSession(context.Background(), input) if err != nil { - return err + return nil, err } r, err := json.Marshal(res) if err != nil { - return err + return nil, err } - cmd := session_manager.MakeStartSessionCmd(string(r), region) + cmd := session_manager.MakeStartSessionCmd(ctx, string(r), region) cmd.Stdout = os.Stdout cmd.Stdin = os.Stdin cmd.Stderr = os.Stderr - return cmd.Run() + return cmd, nil }