Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: translate TaskRunner #1

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ linters:
- exportloopref
- forbidigo
- funlen
- gci
- gochecknoglobals
- gocognit
- goconst
Expand Down
2 changes: 1 addition & 1 deletion .tool-versions
Original file line number Diff line number Diff line change
@@ -1 +1 @@
golang 1.23.1
golang 1.23.3
30 changes: 30 additions & 0 deletions src/aws/cloudwatch.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package aws

import (
"context"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs"
"github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs/types"
)

type cloudwatchLogsClientAPI interface {
GetLogEvents(ctx context.Context, params *cloudwatchlogs.GetLogEventsInput, optFns ...func(*cloudwatchlogs.Options)) (*cloudwatchlogs.GetLogEventsOutput, error)
}

type LogDetails struct {
logGroupName string
logStreamName string
}

func RetrieveLogs(ctx context.Context, cloudwatchLogsClientAPI cloudwatchLogsClientAPI, loggingDetails LogDetails) ([]types.OutputLogEvent, error) {
response, err := cloudwatchLogsClientAPI.GetLogEvents(ctx, &cloudwatchlogs.GetLogEventsInput{
LogStreamName: &loggingDetails.logStreamName,
LogGroupName: &loggingDetails.logGroupName,
StartFromHead: aws.Bool(true),
})
if err != nil {
return []types.OutputLogEvent{}, err
}
return response.Events, nil
}
100 changes: 100 additions & 0 deletions src/aws/cloudwatch_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package aws

import (
"context"
"errors"
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs"
"github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type mockCloudwatchLogsClient struct {
mockGetLogEvents func(ctx context.Context, params *cloudwatchlogs.GetLogEventsInput, optFns ...func(*cloudwatchlogs.Options)) (*cloudwatchlogs.GetLogEventsOutput, error)
}

func (m mockCloudwatchLogsClient) GetLogEvents(ctx context.Context, params *cloudwatchlogs.GetLogEventsInput, optFns ...func(*cloudwatchlogs.Options)) (*cloudwatchlogs.GetLogEventsOutput, error) {
return m.mockGetLogEvents(ctx, params, optFns...)
}

func TestRetrieveLogs(t *testing.T) {
input := LogDetails{
logGroupName: "test-group",
logStreamName: "test-stream/test-container/07cc583696bd44e0be450bff7314ddaf",
}

events := []types.OutputLogEvent{
{
IngestionTime: aws.Int64(0),
Message: aws.String("beans have been detected in the system"),
Timestamp: aws.Int64(0),
},
{
IngestionTime: aws.Int64(1),
Message: aws.String("beans have been removed from the system"),
Timestamp: aws.Int64(1),
},
}

positiveTests := []struct {
name string
input LogDetails
client mockCloudwatchLogsClient
expected []types.OutputLogEvent
}{
{
// This is a regression test to ensure the function signature remains the same
name: "given a valid LogDetails input, return the events of the log stream",
input: input,
client: mockCloudwatchLogsClient{
mockGetLogEvents: func(ctx context.Context, params *cloudwatchlogs.GetLogEventsInput, optFns ...func(*cloudwatchlogs.Options)) (*cloudwatchlogs.GetLogEventsOutput, error) {
return &cloudwatchlogs.GetLogEventsOutput{Events: events}, nil
},
},
expected: events,
},
}

for _, tc := range positiveTests {
t.Run(tc.name, func(t *testing.T) {
result, err := RetrieveLogs(context.TODO(), tc.client, tc.input)

t.Logf("result: %v", result)
t.Logf("expected: %v", tc.expected)
require.NoError(t, err)
assert.Equal(t, tc.expected, result)
})
}

negativeTests := []struct {
name string
input LogDetails
client mockCloudwatchLogsClient
expected []types.OutputLogEvent
}{
{
name: "when the underlying cloudwatch client experiences an error, return it in the function ",
input: input,
client: mockCloudwatchLogsClient{
mockGetLogEvents: func(ctx context.Context, params *cloudwatchlogs.GetLogEventsInput, optFns ...func(*cloudwatchlogs.Options)) (*cloudwatchlogs.GetLogEventsOutput, error) {
return &cloudwatchlogs.GetLogEventsOutput{}, errors.New("generic cloudwatch error")
},
},
expected: []types.OutputLogEvent{},
},
}

for _, tc := range negativeTests {
t.Run(tc.name, func(t *testing.T) {
result, err := RetrieveLogs(context.TODO(), tc.client, tc.input)

t.Logf("result: %v", result)
t.Logf("expected: %v", tc.expected)
require.Error(t, err)
assert.Equal(t, tc.expected, result)
})
}
}
120 changes: 120 additions & 0 deletions src/aws/ecs.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package aws

import (
"context"
"encoding/json"
"fmt"
"strings"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/ecs"
"github.com/aws/aws-sdk-go-v2/service/ecs/types"
)

// internal interface for ecs
type EcsClientAPI interface {
RunTask(ctx context.Context, params *ecs.RunTaskInput, optFns ...func(*ecs.Options)) (*ecs.RunTaskOutput, error)
DescribeTasks(ctx context.Context, params *ecs.DescribeTasksInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTasksOutput, error)
DescribeTaskDefinition(ctx context.Context, params *ecs.DescribeTaskDefinitionInput, optFns ...func(*ecs.Options)) (*ecs.DescribeTaskDefinitionOutput, error)
}

type ecsWaiterAPI interface {
WaitForOutput(ctx context.Context, params *ecs.DescribeTasksInput, maxWaitDur time.Duration, optFns ...func(*ecs.TasksStoppedWaiterOptions)) (*ecs.DescribeTasksOutput, error)
}

func SubmitTask(ctx context.Context, ecsAPI EcsClientAPI, input *TaskRunnerConfiguration) (string, error) {
response, err := ecsAPI.RunTask(ctx, &ecs.RunTaskInput{
Cluster: &input.Cluster,
LaunchType: "FARGATE",
Overrides: &types.TaskOverride{
ContainerOverrides: []types.ContainerOverride{
{
Name: aws.String("migrations-runner"),
Command: input.Command,
},
},
},
TaskDefinition: &input.TaskDefinitionArn,
NetworkConfiguration: &types.NetworkConfiguration{
AwsvpcConfiguration: &types.AwsVpcConfiguration{
Subnets: input.SubnetIds,
SecurityGroups: input.SecurityGroupIds,
},
},
})
if err != nil {
return "", err
}

if response.Tasks[0].TaskArn == nil {
responseJSON, err := json.Marshal(response)
if err != nil {
return "", fmt.Errorf("error in unmarshalling response for failed RunTask: %w", err)
}
return "", fmt.Errorf("ecs:RunTask response contains no TaskArn: %v", string(responseJSON))
}

// this is working on the assumption that only one task is returned
return *response.Tasks[0].TaskArn, nil
}

func WaitForCompletion(ctx context.Context, waiter ecsWaiterAPI, taskArn string) (*ecs.DescribeTasksOutput, error) {
cluster := ClusterFromTaskArn(taskArn)

// TODO: This magic number will be resolved in a future piece of work, not going to refactor this right now
maxWaitDuration := 15 * time.Minute //nolint:mnd
result, err := waiter.WaitForOutput(ctx, &ecs.DescribeTasksInput{
Cluster: aws.String(cluster),
Tasks: []string{taskArn},
}, maxWaitDuration)

// the `DescribeTasksOutput` struct is returned even if there is an error. Counterintuitively, it happens to include failure information
// which we may want to surface from the `Failures` struct field
if err != nil {
return result, err
}

// In a successful scenario, we should have a `tasks` slice with a single element
return result, nil
}

func ClusterFromTaskArn(arn string) string {
parts := strings.Split(arn, "/")
return parts[len(parts)-2]
}

func TaskIDFromArn(taskArn string) string {
parts := strings.Split(taskArn, "/")
return parts[len(parts)-1]
}

// Acquires LogStream details for given ECS Task
func FindLogStreamFromTask(ctx context.Context, ecsClientAPI EcsClientAPI, task types.Task) (LogDetails, error) {
response, err := ecsClientAPI.DescribeTaskDefinition(ctx, &ecs.DescribeTaskDefinitionInput{
TaskDefinition: task.TaskDefinitionArn,
})
if err != nil {
return LogDetails{}, err
}

// TODO: This was originally part of the if statement below, but it was moved out to avoid a nil pointer dereference when getting the `logGroupName` and `streamPrefix` values
if len(response.TaskDefinition.ContainerDefinitions) == 0 {
return LogDetails{}, fmt.Errorf("ecs:DescribeTaskDefinition response is missing ContainerDefinitions data: %v", response)
}

container := response.TaskDefinition.ContainerDefinitions[0] // assume first container is the application container
logGroupName := container.LogConfiguration.Options["awslogs-group"]
//NOTE: Takes the format: prefix-name/container-name/ecs-task-id
streamPrefix := container.LogConfiguration.Options["awslogs-stream-prefix"]

// We need the logGroupName, streamPrefix, and a container name to be able to produce a FindLogStreamOutput in full
if logGroupName == "" || streamPrefix == "" {
return LogDetails{}, fmt.Errorf("ecs:DescribeTaskDefinition response does not conftain required logging configuration: %v", response)
}

return LogDetails{
logGroupName: logGroupName,
logStreamName: fmt.Sprintf("%s/%s/%s", streamPrefix, *container.Name, TaskIDFromArn(*task.TaskArn)),
}, nil
}
Loading
Loading