diff --git a/pubsub/aws/snssqs/snssqs.go b/pubsub/aws/snssqs/snssqs.go index 23a31409c7..024f7629ee 100644 --- a/pubsub/aws/snssqs/snssqs.go +++ b/pubsub/aws/snssqs/snssqs.go @@ -7,10 +7,12 @@ import ( "fmt" "strconv" "strings" + "sync" "github.com/aws/aws-sdk-go/aws" - sns "github.com/aws/aws-sdk-go/service/sns" - sqs "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go/service/sns" + "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go/service/sts" gonanoid "github.com/matoous/go-nanoid/v2" @@ -21,15 +23,16 @@ import ( type snsSqs struct { // key is the topic name, value is the ARN of the topic. - topics map[string]string + topics sync.Map // key is the sanitized topic name, value is the actual topic name. - topicSanitized map[string]string + topicsSanitized sync.Map // key is the topic name, value holds the ARN of the queue and its url. - queues map[string]*sqsQueueInfo - // sns to sqs subscriptions. - subscriptions []*string + queues sync.Map + // key is a composite key of queue ARN and topic ARN mapping to subscription ARN. + subscriptions sync.Map snsClient *sns.SNS sqsClient *sqs.SQS + stsClient *sts.STS metadata *snsSqsMetadata logger logger.Logger id string @@ -72,6 +75,10 @@ type snsSqsMetadata struct { messageWaitTimeSeconds int64 // maximum number of messages to receive from the queue at a time. Default: 10, Maximum: 10. messageMaxNumber int64 + // disable resource provisioning of SNS and SQS. + disableEntityManagement bool + // aws account ID. + accountID string } const ( @@ -89,9 +96,8 @@ func NewSnsSqs(l logger.Logger) pubsub.PubSub { } return &snsSqs{ - logger: l, - subscriptions: []*string{}, - id: id, + logger: l, + id: id, } } @@ -174,12 +180,12 @@ func (s *snsSqs) getSnsSqsMetatdata(metadata pubsub.Metadata) (*snsSqsMetadata, } if val, ok := getAliasedProperty([]string{"awsAccountID", "accessKey"}, metadata); ok { - s.logger.Debugf("AccessKey: %s", val) + s.logger.Debugf("accessKey: %s", val) md.AccessKey = val } if val, ok := getAliasedProperty([]string{"awsSecret", "secretKey"}, metadata); ok { - s.logger.Debugf("awsToken: %s", val) + s.logger.Debugf("secretKey: %s", val) md.SecretKey = val } @@ -288,6 +294,14 @@ func (s *snsSqs) getSnsSqsMetatdata(metadata pubsub.Metadata) (*snsSqsMetadata, md.messageMaxNumber = maxNumber } + if val, ok := props["disableEntityManagement"]; ok { + parsed, err := parseBool(val, "disableEntityManagement") + if err != nil { + return nil, err + } + md.disableEntityManagement = parsed + } + return &md, nil } @@ -299,22 +313,38 @@ func (s *snsSqs) Init(metadata pubsub.Metadata) error { s.metadata = md - // both Publish and Subscribe need reference the topic ARN - // track these ARNs in this map. - s.topics = make(map[string]string) - s.topicSanitized = make(map[string]string) - s.queues = make(map[string]*sqsQueueInfo) + // both Publish and Subscribe need reference the topic ARN, queue ARN and subscription ARN between topic and queue + // track these ARNs in these maps. + s.topics = sync.Map{} + s.topicsSanitized = sync.Map{} + s.queues = sync.Map{} + s.subscriptions = sync.Map{} + sess, err := aws_auth.GetClient(md.AccessKey, md.SecretKey, md.SessionToken, md.Region, md.Endpoint) if err != nil { return fmt.Errorf("error creating an AWS client: %w", err) } + + s.stsClient = sts.New(sess) + callerIDOutput, err := s.stsClient.GetCallerIdentity(&sts.GetCallerIdentityInput{}) + if err != nil { + return fmt.Errorf("error fetching sts caller ID: %w", err) + } + + s.metadata.accountID = *callerIDOutput.Account + s.snsClient = sns.New(sess) s.sqsClient = sqs.New(sess) return nil } -func (s *snsSqs) createTopic(topic string) (string, string, error) { +func (s *snsSqs) buildARN(serviceName, entityName string) string { + // arn:aws:sns:us-east-1:302212680347:aws-controltower-SecurityNotifications + return fmt.Sprintf("arn:aws:%s:%s:%s:%s", serviceName, s.metadata.Region, s.metadata.accountID, entityName) +} + +func (s *snsSqs) createTopic(topic string) (string, error) { sanitizedName := nameToAWSSanitizedName(topic, s.metadata.fifo) snsCreateTopicInput := &sns.CreateTopicInput{ Name: aws.String(sanitizedName), @@ -328,35 +358,58 @@ func (s *snsSqs) createTopic(topic string) (string, string, error) { createTopicResponse, err := s.snsClient.CreateTopic(snsCreateTopicInput) if err != nil { - return "", "", fmt.Errorf("error while creating an SNS topic: %w", err) + return "", fmt.Errorf("error while creating an SNS topic: %w", err) + } + + return *(createTopicResponse.TopicArn), nil +} + +func (s *snsSqs) getTopicArn(topic string) (string, error) { + arn := s.buildARN("sns", topic) + getTopicOutput, err := s.snsClient.GetTopicAttributes(&sns.GetTopicAttributesInput{TopicArn: aws.String(arn)}) + if err != nil { + return "", fmt.Errorf("error: %w while getting topic: %v with arn: %v", err, topic, arn) } - return *(createTopicResponse.TopicArn), sanitizedName, nil + return *getTopicOutput.Attributes["TopicArn"], nil } // get the topic ARN from the topics map. If it doesn't exist in the map, try to fetch it from AWS, if it doesn't exist // at all, issue a request to create the topic. func (s *snsSqs) getOrCreateTopic(topic string) (string, error) { - topicArn, ok := s.topics[topic] + var ( + err error + topicArn string + ) - if ok { - s.logger.Debugf("found existing topic ARN for topic %s: %s", topic, topicArn) + if topicArnCached, ok := s.topics.Load(topic); ok { + s.logger.Debugf("found existing topic ARN for topic %s: %s", topic, topicArnCached) - return topicArn, nil + return topicArnCached.(string), nil } + // creating queues is idempotent, the names serve as unique keys among a given region. + s.logger.Debugf("No SNS topic arn found for %s\nCreating SNS topic", topic) - s.logger.Debugf("no topic ARN found for %s\n Creating topic instead.", topic) + sanitizedName := nameToAWSSanitizedName(topic, s.metadata.fifo) + if !s.metadata.disableEntityManagement { + topicArn, err = s.createTopic(sanitizedName) + if err != nil { + s.logger.Errorf("error creating new topic %s: %w", topic, err) - topicArn, sanitizedName, err := s.createTopic(topic) - if err != nil { - s.logger.Errorf("error creating new topic %s: %v", topic, err) + return "", err + } + } else { + topicArn, err = s.getTopicArn(sanitizedName) + if err != nil { + s.logger.Errorf("error fetching info for topic %s: %w", topic, err) - return "", err + return "", err + } } // record topic ARN. - s.topics[topic] = topicArn - s.topicSanitized[sanitizedName] = topic + s.topics.Store(topic, topicArn) + s.topicsSanitized.Store(sanitizedName, topic) return topicArn, nil } @@ -392,25 +445,56 @@ func (s *snsSqs) createQueue(queueName string) (*sqsQueueInfo, error) { }, nil } +func (s *snsSqs) getQueueArn(queueName string) (*sqsQueueInfo, error) { + queueURLOutput, err := s.sqsClient.GetQueueUrl(&sqs.GetQueueUrlInput{QueueName: aws.String(queueName), QueueOwnerAWSAccountId: aws.String(s.metadata.accountID)}) + if err != nil { + return nil, fmt.Errorf("error: %w while getting url of queue: %s", err, queueName) + } + url := queueURLOutput.QueueUrl + + var getQueueOutput *sqs.GetQueueAttributesOutput + getQueueOutput, err = s.sqsClient.GetQueueAttributes(&sqs.GetQueueAttributesInput{QueueUrl: url, AttributeNames: []*string{aws.String("QueueArn")}}) + if err != nil { + return nil, fmt.Errorf("error: %w while getting information for queue: %s, with url: %s", err, queueName, *url) + } + + return &sqsQueueInfo{arn: *getQueueOutput.Attributes["QueueArn"], url: *url}, nil +} + func (s *snsSqs) getOrCreateQueue(queueName string) (*sqsQueueInfo, error) { - queueArn, ok := s.queues[queueName] + var ( + err error + queueInfo *sqsQueueInfo + ) - if ok { - s.logger.Debugf("Found queue arn for %s: %s", queueName, queueArn) + if cachedQueueInfo, ok := s.queues.Load(queueName); ok { + s.logger.Debugf("Found queue arn for %s: %s", queueName, cachedQueueInfo.(*sqsQueueInfo).arn) - return queueArn, nil + return cachedQueueInfo.(*sqsQueueInfo), nil } // creating queues is idempotent, the names serve as unique keys among a given region. - s.logger.Debugf("No queue arn found for %s\nCreating queue", queueName) + s.logger.Debugf("No SQS queue arn found for %s\nCreating SQS queue", queueName) - queueInfo, err := s.createQueue(queueName) - if err != nil { - s.logger.Errorf("Error creating queue %s: %v", queueName, err) + sanitizedName := nameToAWSSanitizedName(queueName, s.metadata.fifo) + + if !s.metadata.disableEntityManagement { + queueInfo, err = s.createQueue(sanitizedName) + if err != nil { + s.logger.Errorf("Error creating queue %s: %v", queueName, err) - return nil, err + return nil, err + } + } else { + queueInfo, err = s.getQueueArn(sanitizedName) + if err != nil { + s.logger.Errorf("error fetching info for queue %s: %w", queueName, err) + + return nil, err + } } - s.queues[queueName] = queueInfo + s.queues.Store(queueName, queueInfo) + s.logger.Debugf("Created SQS queue: %s: with arn: %s", queueName, queueInfo.arn) return queueInfo, nil } @@ -427,6 +511,76 @@ func (s *snsSqs) getMessageGroupID(req *pubsub.PublishRequest) *string { return &fifoMessageGroupID } +func (s *snsSqs) createSnsSqsSubscription(queueArn, topicArn string) (string, error) { + subscribeOutput, err := s.snsClient.Subscribe(&sns.SubscribeInput{ + Attributes: nil, + Endpoint: aws.String(queueArn), // create SQS queue per subscription. + Protocol: aws.String("sqs"), + ReturnSubscriptionArn: nil, + TopicArn: aws.String(topicArn), + }) + if err != nil { + wrappedErr := fmt.Errorf("error subscribing to sns topic arn: %s, to queue arn: %s %w", topicArn, queueArn, err) + s.logger.Error(wrappedErr) + + return "", wrappedErr + } + + return *subscribeOutput.SubscriptionArn, nil +} + +func (s *snsSqs) getSnsSqsSubscriptionArn(topicArn string) (string, error) { + listSubscriptionsOutput, err := s.snsClient.ListSubscriptionsByTopic(&sns.ListSubscriptionsByTopicInput{TopicArn: aws.String(topicArn)}) + if err != nil { + return "", fmt.Errorf("error listing subsriptions for topic arn: %v: %w", topicArn, err) + } + + for _, subscription := range listSubscriptionsOutput.Subscriptions { + if *subscription.TopicArn == topicArn { + return *subscription.SubscriptionArn, nil + } + } + + return "", fmt.Errorf("sns sqs subscription not found for topic arn") +} + +func (s *snsSqs) getOrCreateSnsSqsSubscription(queueArn, topicArn string) (string, error) { + var ( + subscriptionArn string + err error + ) + + compositeKey := fmt.Sprintf("%s:%s", queueArn, topicArn) + if cachedSubscriptionArn, ok := s.subscriptions.Load(compositeKey); ok { + s.logger.Debugf("Found subscription of queue arn: %s to topic arn: %s: %s", queueArn, topicArn, cachedSubscriptionArn) + + return cachedSubscriptionArn.(string), nil + } + + s.logger.Debugf("No subscription arn found of queue arn:%s to topic arn: %s\nCreating subscription", queueArn, topicArn) + + if !s.metadata.disableEntityManagement { + subscriptionArn, err = s.createSnsSqsSubscription(queueArn, topicArn) + if err != nil { + s.logger.Errorf("Error creating subscription %s: %v", subscriptionArn, err) + + return "", err + } + } else { + subscriptionArn, err = s.getSnsSqsSubscriptionArn(topicArn) + if err != nil { + s.logger.Errorf("error fetching info for topic arn %s: %w", topicArn, err) + + return "", err + } + } + + s.subscriptions.Store(compositeKey, subscriptionArn) + s.logger.Debugf("Subscribed to topic %s: %s", topicArn, subscriptionArn) + + return subscriptionArn, nil +} + func (s *snsSqs) Publish(req *pubsub.PublishRequest) error { topicArn, err := s.getOrCreateTopic(req.Topic) if err != nil { @@ -497,10 +651,13 @@ func (s *snsSqs) handleMessage(message *sqs.Message, queueInfo, deadLettersQueue "message received greater than %v times, deleting this message without further processing", s.metadata.messageRetryLimit) } // ... else, there is no need to actively do something if we reached the limit defined in messageReceiveLimit as the message had - // already been moved to the dead-letters queue by SQS. - if deadLettersQueueInfo != nil && recvCountInt >= s.metadata.messageReceiveLimit { - s.logger.Warnf( - "message received greater than %v times, moving this message without further processing to dead-letters queue: %v", s.metadata.messageReceiveLimit, s.metadata.sqsDeadLettersQueueName) + // already been moved to the dead-letters queue by SQS. meaning, the below condition should not be reached as SQS would not send + // a message if we've already surpassed the s.metadata.messageReceiveLimit value. + if deadLettersQueueInfo != nil && recvCountInt > s.metadata.messageReceiveLimit { + awsErr := fmt.Errorf( + "message received greater than %v times, this message should have been moved without further processing to dead-letters queue: %v", s.metadata.messageReceiveLimit, s.metadata.sqsDeadLettersQueueName) + s.logger.Error(awsErr) + return awsErr } // otherwise try to handle the message. @@ -511,11 +668,18 @@ func (s *snsSqs) handleMessage(message *sqs.Message, queueInfo, deadLettersQueue return fmt.Errorf("error unmarshalling message: %w", err) } - topic := parseTopicArn(messageBody.TopicArn) - topic = s.topicSanitized[topic] + // messageBody.TopicArn can only carry a sanitized topic name as we conform to AWS naming standards. + // for the user to be able to understand the source of the coming message, we'd use the original, + // dirty name to be carried over in the pubsub.NewMessage Topic field. + sanitizedTopic := parseTopicArn(messageBody.TopicArn) + cachedTopic, ok := s.topicsSanitized.Load(sanitizedTopic) + if !ok { + return fmt.Errorf("failed loading topic (sanitized): %s from internal topics cache. SNS topic might be just created", sanitizedTopic) + } + err = handler(context.Background(), &pubsub.NewMessage{ Data: []byte(messageBody.Message), - Topic: topic, + Topic: cachedTopic.(string), }) if err != nil { @@ -535,7 +699,7 @@ func (s *snsSqs) consumeSubscription(queueInfo, deadLettersQueueInfo *sqsQueueIn aws.String(sqs.MessageSystemAttributeNameApproximateReceiveCount), }, MaxNumberOfMessages: aws.Int64(s.metadata.messageMaxNumber), - QueueUrl: &queueInfo.url, + QueueUrl: aws.String(queueInfo.url), VisibilityTimeout: aws.Int64(s.metadata.messageVisibilityTimeout), WaitTimeSeconds: aws.Int64(s.metadata.messageWaitTimeSeconds), }) @@ -695,7 +859,7 @@ func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler) var sqsSetQueueAttributesInput *sqs.SetQueueAttributesInput sqsSetQueueAttributesInput, derr = s.createQueueAttributesWithDeadLetters(queueInfo, deadLettersQueueInfo) if derr != nil { - wrappedErr := fmt.Errorf("error creatubg queue attributes for dead-letter queue: %w", derr) + wrappedErr := fmt.Errorf("error creating queue attributes for dead-letter queue: %w", derr) s.logger.Error(wrappedErr) return wrappedErr @@ -711,23 +875,13 @@ func (s *snsSqs) Subscribe(req pubsub.SubscribeRequest, handler pubsub.Handler) } // subscription creation is idempotent. Subscriptions are unique by topic/queue. - subscribeOutput, err := s.snsClient.Subscribe(&sns.SubscribeInput{ - Attributes: nil, - Endpoint: &queueInfo.arn, // create SQS queue per subscription. - Protocol: aws.String("sqs"), - ReturnSubscriptionArn: nil, - TopicArn: &topicArn, - }) - if err != nil { - wrappedErr := fmt.Errorf("error subscribing to topic %s: %w", req.Topic, err) + if _, err := s.getOrCreateSnsSqsSubscription(queueInfo.arn, topicArn); err != nil { + wrappedErr := fmt.Errorf("error subscribing topic: %s, to queue: %s, with error: %w", topicArn, queueInfo.arn, err) s.logger.Error(wrappedErr) return wrappedErr } - s.subscriptions = append(s.subscriptions, subscribeOutput.SubscriptionArn) - s.logger.Debugf("Subscribed to topic %s: %v", req.Topic, subscribeOutput) - s.consumeSubscription(queueInfo, deadLettersQueueInfo, handler) return nil