Skip to content

Commit

Permalink
[ALS-6074] Fix uploader auth timeouts for instance profile
Browse files Browse the repository at this point in the history
- Refresh auth when rebuilding client
  • Loading branch information
Luke Sikina authored and Luke-Sikina committed Jul 10, 2024
1 parent 0c0653f commit 8718216
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.util.StringUtils;
import software.amazon.awssdk.auth.credentials.*;
import software.amazon.awssdk.http.SdkHttpClient;
Expand All @@ -27,15 +28,6 @@
public class AWSConfiguration {
private static final Logger LOG = LoggerFactory.getLogger(AWSConfiguration.class);

@Value("${aws.s3.access_key_secret:}")
private String secret;

@Value("${aws.s3.access_key_id:}")
private String key;

@Value("${aws.s3.session_token:}")
private String token;

@Value("${aws.s3.institution:}")
private List<String> institutions;

Expand All @@ -57,51 +49,6 @@ public class AWSConfiguration {
@Autowired(required = false)
private SdkHttpClient sdkHttpClient;

@Value("${http.proxyUser:}")
private String proxyUser;

@Bean
@ConditionalOnProperty(name = "production", havingValue = "true")
public StsClient stsClients(
@Autowired AwsCredentials credentials,
@Autowired StsClientBuilder stsClientBuilder
) {
StsClientBuilder builder = stsClientBuilder
.region(Region.US_EAST_1)
.credentialsProvider(StaticCredentialsProvider.create(credentials));

if (StringUtils.hasLength(proxyUser)) {
builder.httpClient(sdkHttpClient);
}

return builder.build();
}

@Bean
@ConditionalOnProperty(name = "aws.authentication.method", havingValue = "user")
AwsCredentials credentials() {
LOG.info("Authentication method is user. Attempting to resolve user credentials.");
if (Strings.isBlank(key)) {
LOG.error("No AWS key. Can't create client. Exiting");
context.close();
}
if (Strings.isBlank(secret)) {
LOG.error("No AWS secret. Can't create client. Exiting");
context.close();
}
if (Strings.isBlank(token)) {
return AwsBasicCredentials.create(key, secret);
} else {
return AwsSessionCredentials.create(key, secret, token);
}
}
@Bean
@ConditionalOnProperty(name = "aws.authentication.method", havingValue = "instance-profile")
AwsCredentials ipCredentials() {
LOG.info("Authentication method is instance-profile. Attempting to resolve instance profile credentials.");
return InstanceProfileCredentialsProvider.create().resolveCredentials();
}

@Bean
@ConditionalOnProperty(name = "production", havingValue = "true")
Map<String, SiteAWSInfo> roleARNs() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package edu.harvard.dbmi.avillach.dataupload.aws;

import org.apache.logging.log4j.util.Strings;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.stereotype.Service;
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.credentials.InstanceProfileCredentialsProvider;

@Service
public class AWSCredentialsService {

private static final Logger LOG = LoggerFactory.getLogger(AWSCredentialsService.class);

private final String authMethod;
private final String secret;
private final String key;
private final String token;
private final ConfigurableApplicationContext context;


private AwsCredentials credentials;

@Autowired
public AWSCredentialsService(
@Value("${aws.authentication.method:}") String authMethod,
@Value("${aws.s3.access_key_secret:}") String secret,
@Value("${aws.s3.access_key_id:}") String key,
@Value("${aws.s3.session_token:}") String token,
ConfigurableApplicationContext context
) {
this.authMethod = authMethod;
this.secret = secret;
this.key = key;
this.token = token;
this.context = context;
}

public AwsCredentials constructCredentials() {
//noinspection SwitchStatementWithTooFewBranches
return switch (authMethod) {
case "instance-profile" -> createInstanceProfileBasedCredentials();
default -> createUserBasedCredentials();
};
}

private AwsCredentials createUserBasedCredentials() {
LOG.info("Authentication method is user. Attempting to resolve user credentials.");
if (Strings.isBlank(key)) {
LOG.error("No AWS key. Can't create client. Exiting");
context.close();
}
if (Strings.isBlank(secret)) {
LOG.error("No AWS secret. Can't create client. Exiting");
context.close();
}
if (Strings.isBlank(token)) {
return AwsBasicCredentials.create(key, secret);
} else {
return AwsSessionCredentials.create(key, secret, token);
}
}

private AwsCredentials createInstanceProfileBasedCredentials() {
LOG.info("Authentication method is instance-profile. Attempting to resolve instance profile credentials.");
return InstanceProfileCredentialsProvider.create().resolveCredentials();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,16 @@
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
import software.amazon.awssdk.http.SdkHttpClient;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.S3ClientBuilder;
import software.amazon.awssdk.services.sts.StsClient;
import software.amazon.awssdk.services.sts.StsClientBuilder;
import software.amazon.awssdk.services.sts.model.AssumeRoleRequest;
import software.amazon.awssdk.services.sts.model.AssumeRoleResponse;
import software.amazon.awssdk.services.sts.model.Credentials;
Expand Down Expand Up @@ -47,9 +50,6 @@ public class SelfRefreshingS3Client {
@Autowired
private ConfigurableApplicationContext context;

@Autowired
private StsClient stsClient;

@Autowired
private Map<String, SiteAWSInfo> roleARNs;

Expand All @@ -59,13 +59,34 @@ public class SelfRefreshingS3Client {
@Autowired(required = false)
private SdkHttpClient sdkHttpClient;

@Autowired
private AWSCredentialsService credentialsService;

@Autowired
private StsClientBuilder stsClientBuilder;

@Value("${http.proxyUser:}")
private String proxyUser;

@PostConstruct
private void refreshClient() {
locks = roleARNs.keySet().stream()
.collect(Collectors.toMap(Function.identity(), (s) -> new ReentrantReadWriteLock()));
roleARNs.keySet().stream().parallel().forEach(this::refreshClient);
}

private StsClient createStsClient() {
StsClientBuilder builder = stsClientBuilder
.region(Region.US_EAST_1)
.credentialsProvider(StaticCredentialsProvider.create(credentialsService.constructCredentials()));

if (StringUtils.hasLength(proxyUser)) {
builder.httpClient(sdkHttpClient);
}

return builder.build();
}

// exposed for testing
void refreshClient(String siteName) {
LOG.info("Starting client refresh for {}", siteName);
Expand All @@ -83,7 +104,7 @@ void refreshClient(String siteName) {
.externalId(roleARNs.get(siteName).externalId())
.durationSeconds(60*60) // 1 hour
.build();
AssumeRoleResponse assumeRoleResponse = stsClient.assumeRole(roleRequest);
AssumeRoleResponse assumeRoleResponse = createStsClient().assumeRole(roleRequest);
if (assumeRoleResponse.credentials() == null ) {
LOG.error("Error assuming role, no credentials returned! Exiting!");
statusService.setClientStatus("error");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,44 +33,6 @@ class AWSConfigurationTest {
@InjectMocks
AWSConfiguration subject;

@Test
void shouldCreateCredentials() {
ReflectionTestUtils.setField(subject, "secret", "s1");
ReflectionTestUtils.setField(subject, "key", "k1");
ReflectionTestUtils.setField(subject, "token", "t1");

AwsCredentials actual = subject.credentials();
AwsSessionCredentials expected = AwsSessionCredentials.create("k1", "s1", "t1");

Assertions.assertEquals(expected, actual);
}

@Test
void shouldNotCreateCredentials() {
ReflectionTestUtils.setField(subject, "secret", "");
ReflectionTestUtils.setField(subject, "key", "k1");
ReflectionTestUtils.setField(subject, "token", "t1");

subject.credentials();

Mockito.verify(context, Mockito.times(1)).close();
}

@Test
void shouldCreateClients() {
AwsSessionCredentials credentials = AwsSessionCredentials.create("k1", "s1", "t1");
Mockito.when(stsClientBuilder.region(Region.US_EAST_1))
.thenReturn(stsClientBuilder);
Mockito.when(stsClientBuilder.credentialsProvider(Mockito.any()))
.thenReturn(stsClientBuilder);
Mockito.when(stsClientBuilder.build())
.thenReturn(stsClient);

StsClient actual = subject.stsClients(credentials, stsClientBuilder);

Assertions.assertEquals(stsClient, actual);
}

@Test
void shouldCreateRoles() {
ReflectionTestUtils.setField(subject, "institutions", List.of("i1", "i2"));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package edu.harvard.dbmi.avillach.dataupload.aws;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.mockito.InjectMocks;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.mock.mockito.MockBean;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.test.util.ReflectionTestUtils;
import software.amazon.awssdk.auth.credentials.AwsCredentials;
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;

import static org.junit.jupiter.api.Assertions.*;

@SpringBootTest
class AWSCredentialsServiceTest {

@Mock
ConfigurableApplicationContext context;

@InjectMocks
AWSCredentialsService subject;

@Test
void shouldCreateCredentials() {
ReflectionTestUtils.setField(subject, "authMethod", "user");
ReflectionTestUtils.setField(subject, "secret", "s1");
ReflectionTestUtils.setField(subject, "key", "k1");
ReflectionTestUtils.setField(subject, "token", "t1");

AwsCredentials actual = subject.constructCredentials();
AwsSessionCredentials expected = AwsSessionCredentials.create("k1", "s1", "t1");

Assertions.assertEquals(expected, actual);
}

@Test
void shouldNotCreateCredentials() {
ReflectionTestUtils.setField(subject, "authMethod", "user");
ReflectionTestUtils.setField(subject, "secret", "");
ReflectionTestUtils.setField(subject, "key", "k1");
ReflectionTestUtils.setField(subject, "token", "t1");

subject.constructCredentials();

Mockito.verify(context, Mockito.times(1)).close();
}
}

0 comments on commit 8718216

Please sign in to comment.