Skip to content

Commit

Permalink
plumb mtls endpoint to grpc channel provider.
Browse files Browse the repository at this point in the history
  • Loading branch information
rmehta19 committed Nov 19, 2024
1 parent 1080de4 commit 0f517a9
Show file tree
Hide file tree
Showing 9 changed files with 99 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP
private final HeaderProvider headerProvider;
private final boolean useS2A;
private final String endpoint;
private final String mtlsEndpoint;
// TODO: remove. envProvider currently provides DirectPath environment variable, and is only used
// during initial rollout for DirectPath. This provider will be removed once the DirectPath
// environment is not used.
Expand Down Expand Up @@ -152,6 +153,7 @@ private InstantiatingGrpcChannelProvider(Builder builder) {
this.executor = builder.executor;
this.headerProvider = builder.headerProvider;
this.endpoint = builder.endpoint;
this.mtlsEndpoint = builder.mtlsEndpoint;
this.useS2A = builder.useS2A;
this.mtlsProvider = builder.mtlsProvider;
this.s2aConfigProvider = builder.s2aConfigProvider;
Expand Down Expand Up @@ -229,6 +231,11 @@ public boolean needsEndpoint() {
return endpoint == null;
}

@Override
public boolean needsMtlsEndpoint() {
return mtlsEndpoint == null;
}

/**
* Specify the endpoint the channel should connect to.
*
Expand All @@ -243,6 +250,21 @@ public TransportChannelProvider withEndpoint(String endpoint) {
return toBuilder().setEndpoint(endpoint).build();
}

/**
* Specify the mtlsEndpoint the channel should connect to.
*
* <p>The value of {@code mtlsEndpoint} must be of the form {@code host:port}.
*
* @param mtlsEndpoint The mtlsEndpoint to connect to
* @return A new {@link InstantiatingGrpcChannelProvider} with the specified mtlsEndpoint
* configured
*/
@Override
public TransportChannelProvider withMtlsEndpoint(String mtlsEndpoint) {
validateEndpoint(mtlsEndpoint);
return toBuilder().setMtlsEndpoint(mtlsEndpoint).build();
}

/**
* Specify whether or not to use S2A.
*
Expand Down Expand Up @@ -590,8 +612,7 @@ private ManagedChannel createSingleChannel() throws IOException {
}
if (channelCredentials != null) {
// Create the channel using S2A-secured channel credentials.
// {@code endpoint} is set to mtlsEndpoint in {@link EndpointContext} when useS2A is true.
builder = Grpc.newChannelBuilder(endpoint, channelCredentials);
builder = Grpc.newChannelBuilder(mtlsEndpoint, channelCredentials);
} else {
// Use default if we cannot initialize channel credentials via DCA or S2A.
builder = ManagedChannelBuilder.forAddress(serviceAddress, port);
Expand Down Expand Up @@ -743,6 +764,7 @@ public static final class Builder {
private Executor executor;
private HeaderProvider headerProvider;
private String endpoint;
private String mtlsEndpoint;
private boolean useS2A;
private EnvironmentProvider envProvider;
private SecureSessionAgent s2aConfigProvider = SecureSessionAgent.create();
Expand Down Expand Up @@ -773,6 +795,7 @@ private Builder(InstantiatingGrpcChannelProvider provider) {
this.executor = provider.executor;
this.headerProvider = provider.headerProvider;
this.endpoint = provider.endpoint;
this.mtlsEndpoint = provider.mtlsEndpoint;
this.useS2A = provider.useS2A;
this.envProvider = provider.envProvider;
this.interceptorProvider = provider.interceptorProvider;
Expand Down Expand Up @@ -843,6 +866,13 @@ public Builder setEndpoint(String endpoint) {
return this;
}

/** Sets the mtlsEndpoint used to reach the service, eg "localhost:8080". */
public Builder setMtlsEndpoint(String mtlsEndpoint) {
validateEndpoint(mtlsEndpoint);
this.mtlsEndpoint = mtlsEndpoint;
return this;
}

Builder setUseS2A(boolean useS2A) {
this.useS2A = useS2A;
return this;
Expand Down Expand Up @@ -876,6 +906,10 @@ public String getEndpoint() {
return endpoint;
}

public String getMtlsEndpoint() {
return mtlsEndpoint;
}

/** The maximum message size allowed to be received on the channel. */
public Builder setMaxInboundMessageSize(Integer max) {
this.maxInboundMessageSize = max;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ void setUp() throws IOException {
when(operationsChannelProvider.getTransportChannel()).thenReturn(transportChannel);
when(operationsChannelProvider.withUseS2A(Mockito.any(boolean.class)))
.thenReturn(operationsChannelProvider);
when(operationsChannelProvider.withMtlsEndpoint(Mockito.any(String.class)))
.thenReturn(operationsChannelProvider);

clock = new FakeApiClock(0L);
executor = RecordingScheduler.create(clock);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,21 @@ public boolean needsEndpoint() {
return false;
}

@Override
public boolean needsMtlsEndpoint() {
return false;
}

@Override
public TransportChannelProvider withEndpoint(String endpoint) {
throw new UnsupportedOperationException("LocalChannelProvider doesn't need an endpoint");
}

@Override
public TransportChannelProvider withMtlsEndpoint(String mtlsEndpoint) {
throw new UnsupportedOperationException("LocalChannelProvider doesn't need an mtlsEndpoint");
}

@Override
public TransportChannelProvider withUseS2A(boolean useS2A) {
// Overriden for technical reasons. This method is a no-op for LocalChannelProvider.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,21 @@ public boolean needsEndpoint() {
return endpoint == null;
}

@Override
public boolean needsMtlsEndpoint() {
return false;
}

@Override
public TransportChannelProvider withEndpoint(String endpoint) {
return toBuilder().setEndpoint(endpoint).build();
}

@Override
public TransportChannelProvider withMtlsEndpoint(String mtlsEndpoint) {
return this;
}

@Override
public TransportChannelProvider withUseS2A(boolean useS2A) {
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,10 @@ public static ClientContext create(StubSettings settings) throws IOException {
if (transportChannelProvider.needsEndpoint()) {
transportChannelProvider = transportChannelProvider.withEndpoint(endpoint);
}
if (transportChannelProvider.needsMtlsEndpoint()) {
transportChannelProvider =
transportChannelProvider.withMtlsEndpoint(endpointContext.mtlsEndpoint());
}
transportChannelProvider = transportChannelProvider.withUseS2A(endpointContext.useS2A());
TransportChannel transportChannel = transportChannelProvider.getTransportChannel();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,10 +272,6 @@ private String determineUniverseDomain() {

/** Determines the fully resolved endpoint and universe domain values */
private String determineEndpoint() throws IOException {
if (shouldUseS2A()) {
return mtlsEndpoint();
}

MtlsProvider mtlsProvider = mtlsProvider() == null ? new MtlsProvider() : mtlsProvider();
// TransportChannelProvider's endpoint will override the ClientSettings' endpoint
String customEndpoint =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,23 @@ public boolean needsEndpoint() {
return false;
}

@Override
public boolean needsMtlsEndpoint() {
return false;
}

@Override
public TransportChannelProvider withEndpoint(String endpoint) {
throw new UnsupportedOperationException(
"FixedTransportChannelProvider doesn't need an endpoint");
}

@Override
public TransportChannelProvider withMtlsEndpoint(String mtlsEndpoint) {
throw new UnsupportedOperationException(
"FixedTransportChannelProvider doesn't need an mtlsEndpoint");
}

@Override
public TransportChannelProvider withUseS2A(boolean useS2A) throws UnsupportedOperationException {
// Overriden for technical reasons. This method is a no-op for FixedTransportChannelProvider.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,23 @@ public interface TransportChannelProvider {
/** True if the TransportProvider has no endpoint set. */
boolean needsEndpoint();

/** True if the TransportProvider has no mtlsEndpoint set. */
boolean needsMtlsEndpoint();

/**
* Sets the endpoint to use when constructing a new {@link TransportChannel}.
*
* <p>This method should only be called if {@link #needsEndpoint()} returns true.
*/
TransportChannelProvider withEndpoint(String endpoint);

/**
* Sets the mtlsEndpoint to use when constructing a new {@link TransportChannel}.
*
* <p>This method should only be called if {@link #needsMtlsEndpoint()} returns true.
*/
TransportChannelProvider withMtlsEndpoint(String mtlsEndpoint);

/** Sets whether to use S2A when constructing a new {@link TransportChannel}. */
default TransportChannelProvider withUseS2A(boolean useS2A) {
throw new UnsupportedOperationException("S2A is not supported");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,11 @@ public boolean needsEndpoint() {
return true;
}

@Override
public boolean needsMtlsEndpoint() {
return false;
}

@Override
public String getEndpoint() {
return endpoint;
Expand All @@ -195,6 +200,17 @@ public TransportChannelProvider withEndpoint(String endpoint) {
endpoint);
}

@Override
public TransportChannelProvider withMtlsEndpoint(String mtlsEndpoint) {
return new FakeTransportProvider(
this.transport,
this.executor,
this.shouldAutoClose,
this.headers,
this.credentials,
this.endpoint);
}

@Override
public TransportChannelProvider withUseS2A(boolean useS2A) {
return new FakeTransportProvider(
Expand Down

0 comments on commit 0f517a9

Please sign in to comment.