Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: plumb mtlsEndpoint to gRPC Channel provider instead of setting it in EndpointContext. #3386

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP
private final HeaderProvider headerProvider;
private final boolean useS2A;
private final String endpoint;
private final String mtlsEndpoint;
// TODO: remove. envProvider currently provides DirectPath environment variable, and is only used
// during initial rollout for DirectPath. This provider will be removed once the DirectPath
// environment is not used.
Expand Down Expand Up @@ -152,6 +153,7 @@ private InstantiatingGrpcChannelProvider(Builder builder) {
this.executor = builder.executor;
this.headerProvider = builder.headerProvider;
this.endpoint = builder.endpoint;
this.mtlsEndpoint = builder.mtlsEndpoint;
this.useS2A = builder.useS2A;
this.mtlsProvider = builder.mtlsProvider;
this.s2aConfigProvider = builder.s2aConfigProvider;
Expand Down Expand Up @@ -229,6 +231,11 @@ public boolean needsEndpoint() {
return endpoint == null;
}

@Override
public boolean needsMtlsEndpoint() {
return mtlsEndpoint == null;
}

/**
* Specify the endpoint the channel should connect to.
*
Expand All @@ -243,6 +250,21 @@ public TransportChannelProvider withEndpoint(String endpoint) {
return toBuilder().setEndpoint(endpoint).build();
}

/**
* Specify the mtlsEndpoint the channel should connect to.
*
* <p>The value of {@code mtlsEndpoint} must be of the form {@code host:port}.
*
* @param mtlsEndpoint The mtlsEndpoint to connect to
* @return A new {@link InstantiatingGrpcChannelProvider} with the specified mtlsEndpoint
* configured
*/
@Override
public TransportChannelProvider withMtlsEndpoint(String mtlsEndpoint) {
validateEndpoint(mtlsEndpoint);
return toBuilder().setMtlsEndpoint(mtlsEndpoint).build();
}

/**
* Specify whether or not to use S2A.
*
Expand Down Expand Up @@ -590,8 +612,7 @@ private ManagedChannel createSingleChannel() throws IOException {
}
if (channelCredentials != null) {
// Create the channel using S2A-secured channel credentials.
// {@code endpoint} is set to mtlsEndpoint in {@link EndpointContext} when useS2A is true.
builder = Grpc.newChannelBuilder(endpoint, channelCredentials);
builder = Grpc.newChannelBuilder(mtlsEndpoint, channelCredentials);
} else {
// Use default if we cannot initialize channel credentials via DCA or S2A.
builder = ManagedChannelBuilder.forAddress(serviceAddress, port);
Expand Down Expand Up @@ -743,6 +764,7 @@ public static final class Builder {
private Executor executor;
private HeaderProvider headerProvider;
private String endpoint;
private String mtlsEndpoint;
private boolean useS2A;
private EnvironmentProvider envProvider;
private SecureSessionAgent s2aConfigProvider = SecureSessionAgent.create();
Expand Down Expand Up @@ -773,6 +795,7 @@ private Builder(InstantiatingGrpcChannelProvider provider) {
this.executor = provider.executor;
this.headerProvider = provider.headerProvider;
this.endpoint = provider.endpoint;
this.mtlsEndpoint = provider.mtlsEndpoint;
this.useS2A = provider.useS2A;
this.envProvider = provider.envProvider;
this.interceptorProvider = provider.interceptorProvider;
Expand Down Expand Up @@ -843,6 +866,13 @@ public Builder setEndpoint(String endpoint) {
return this;
}

/** Sets the mtlsEndpoint used to reach the service, eg "localhost:8080". */
public Builder setMtlsEndpoint(String mtlsEndpoint) {
validateEndpoint(mtlsEndpoint);
this.mtlsEndpoint = mtlsEndpoint;
return this;
}

Builder setUseS2A(boolean useS2A) {
this.useS2A = useS2A;
return this;
Expand Down Expand Up @@ -876,6 +906,10 @@ public String getEndpoint() {
return endpoint;
}

public String getMtlsEndpoint() {
return mtlsEndpoint;
}

/** The maximum message size allowed to be received on the channel. */
public Builder setMaxInboundMessageSize(Integer max) {
this.maxInboundMessageSize = max;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ void setUp() throws IOException {
when(operationsChannelProvider.getTransportChannel()).thenReturn(transportChannel);
when(operationsChannelProvider.withUseS2A(Mockito.any(boolean.class)))
.thenReturn(operationsChannelProvider);
when(operationsChannelProvider.withMtlsEndpoint(Mockito.any(String.class)))
.thenReturn(operationsChannelProvider);

clock = new FakeApiClock(0L);
executor = RecordingScheduler.create(clock);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,21 @@ public boolean needsEndpoint() {
return false;
}

@Override
public boolean needsMtlsEndpoint() {
return false;
}

@Override
public TransportChannelProvider withEndpoint(String endpoint) {
throw new UnsupportedOperationException("LocalChannelProvider doesn't need an endpoint");
}

@Override
public TransportChannelProvider withMtlsEndpoint(String mtlsEndpoint) {
throw new UnsupportedOperationException("LocalChannelProvider doesn't need an mtlsEndpoint");
}

@Override
public TransportChannelProvider withUseS2A(boolean useS2A) {
// Overriden for technical reasons. This method is a no-op for LocalChannelProvider.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,21 @@ public boolean needsEndpoint() {
return endpoint == null;
}

@Override
public boolean needsMtlsEndpoint() {
return false;
}

@Override
public TransportChannelProvider withEndpoint(String endpoint) {
return toBuilder().setEndpoint(endpoint).build();
}

@Override
public TransportChannelProvider withMtlsEndpoint(String mtlsEndpoint) {
return this;
}

@Override
public TransportChannelProvider withUseS2A(boolean useS2A) {
return this;
Expand Down
10 changes: 10 additions & 0 deletions gax-java/gax/clirr-ignored-differences.xml
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,14 @@
<className>com/google/api/gax/rpc/TransportChannelProvider</className>
<method>* withUseS2A(*)</method>
</difference>
<difference>
<differenceType>7012</differenceType>
<className>com/google/api/gax/rpc/TransportChannelProvider</className>
<method>* needsMtlsEndpoint()</method>
</difference>
<difference>
<differenceType>7012</differenceType>
<className>com/google/api/gax/rpc/TransportChannelProvider</className>
<method>* withMtlsEndpoint(*)</method>
</difference>
</differences>
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,10 @@ public static ClientContext create(StubSettings settings) throws IOException {
if (transportChannelProvider.needsEndpoint()) {
transportChannelProvider = transportChannelProvider.withEndpoint(endpoint);
}
if (transportChannelProvider.needsMtlsEndpoint()) {
transportChannelProvider =
transportChannelProvider.withMtlsEndpoint(endpointContext.mtlsEndpoint());
}
transportChannelProvider = transportChannelProvider.withUseS2A(endpointContext.useS2A());
TransportChannel transportChannel = transportChannelProvider.getTransportChannel();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,10 +272,6 @@ private String determineUniverseDomain() {

/** Determines the fully resolved endpoint and universe domain values */
private String determineEndpoint() throws IOException {
if (shouldUseS2A()) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WDYT of renaming this to canUseS2A(), since that's what the result of this function is really indicating.

For example, it is possible that this function returns true and canUseDirectPath in InstantiatingGrpcChannelProvider` returns true, and we therefore end up not using S2A.

return mtlsEndpoint();
}

MtlsProvider mtlsProvider = mtlsProvider() == null ? new MtlsProvider() : mtlsProvider();
// TransportChannelProvider's endpoint will override the ClientSettings' endpoint
String customEndpoint =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,23 @@ public boolean needsEndpoint() {
return false;
}

@Override
public boolean needsMtlsEndpoint() {
return false;
}

@Override
public TransportChannelProvider withEndpoint(String endpoint) {
throw new UnsupportedOperationException(
"FixedTransportChannelProvider doesn't need an endpoint");
}

@Override
public TransportChannelProvider withMtlsEndpoint(String mtlsEndpoint) {
throw new UnsupportedOperationException(
"FixedTransportChannelProvider doesn't need an mtlsEndpoint");
}

@Override
public TransportChannelProvider withUseS2A(boolean useS2A) throws UnsupportedOperationException {
// Overriden for technical reasons. This method is a no-op for FixedTransportChannelProvider.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,23 @@ public interface TransportChannelProvider {
/** True if the TransportProvider has no endpoint set. */
boolean needsEndpoint();

/** True if the TransportProvider has no mtlsEndpoint set. */
boolean needsMtlsEndpoint();

/**
* Sets the endpoint to use when constructing a new {@link TransportChannel}.
*
* <p>This method should only be called if {@link #needsEndpoint()} returns true.
*/
TransportChannelProvider withEndpoint(String endpoint);

/**
* Sets the mtlsEndpoint to use when constructing a new {@link TransportChannel}.
*
* <p>This method should only be called if {@link #needsMtlsEndpoint()} returns true.
*/
TransportChannelProvider withMtlsEndpoint(String mtlsEndpoint);

/** Sets whether to use S2A when constructing a new {@link TransportChannel}. */
default TransportChannelProvider withUseS2A(boolean useS2A) {
throw new UnsupportedOperationException("S2A is not supported");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,11 @@ public boolean needsEndpoint() {
return true;
}

@Override
public boolean needsMtlsEndpoint() {
return false;
}

@Override
public String getEndpoint() {
return endpoint;
Expand All @@ -195,6 +200,17 @@ public TransportChannelProvider withEndpoint(String endpoint) {
endpoint);
}

@Override
public TransportChannelProvider withMtlsEndpoint(String mtlsEndpoint) {
return new FakeTransportProvider(
this.transport,
this.executor,
this.shouldAutoClose,
this.headers,
this.credentials,
this.endpoint);
}

@Override
public TransportChannelProvider withUseS2A(boolean useS2A) {
return new FakeTransportProvider(
Expand Down
Loading