Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix io.grpc.servlet.jakarta.ServletAdapter#getHeaders #2

Open
wants to merge 5 commits into
base: grpc-web-servlet-filter
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import io.grpc.ExperimentalApi;
import io.grpc.Grpc;
import io.grpc.InternalLogId;
import io.grpc.InternalMetadata;
import io.grpc.Metadata;
import io.grpc.ServerStreamTracer;
import io.grpc.Status;
Expand All @@ -38,12 +37,12 @@
import java.net.InetSocketAddress;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.List;
import java.util.Locale;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.Logger;

import static com.google.common.base.Preconditions.checkArgument;
Expand Down Expand Up @@ -180,25 +179,26 @@ private static Metadata getHeaders(HttpServletRequest req) {
Enumeration<String> headerNames = req.getHeaderNames();
checkNotNull(
headerNames, "Servlet container does not allow HttpServletRequest.getHeaderNames()");
List<byte[]> byteArrays = new ArrayList<>();
Metadata metadata = new Metadata();
while (headerNames.hasMoreElements()) {
String headerName = headerNames.nextElement();
Enumeration<String> values = req.getHeaders(headerName);
if (values == null) {
continue;
}
headerName = headerName.toLowerCase(Locale.ROOT);
boolean isBinaryHeader = headerName.endsWith(Metadata.BINARY_HEADER_SUFFIX);
while (values.hasMoreElements()) {
String value = values.nextElement();
if (headerName.endsWith(Metadata.BINARY_HEADER_SUFFIX)) {
byteArrays.add(headerName.getBytes(StandardCharsets.US_ASCII));
byteArrays.add(BaseEncoding.base64().decode(value));
if (isBinaryHeader) {
metadata.put(Metadata.Key.of(headerName, Metadata.BINARY_BYTE_MARSHALLER),
BaseEncoding.base64().decode(value));
} else {
byteArrays.add(headerName.getBytes(StandardCharsets.US_ASCII));
byteArrays.add(value.getBytes(StandardCharsets.US_ASCII));
metadata.put(Metadata.Key.of(headerName, Metadata.ASCII_STRING_MARSHALLER), value);
}
}
}
return InternalMetadata.newMetadata(byteArrays.toArray(new byte[][] {}));
return metadata;
}

// This method must use HttpRequest#getRequestURL or HttpUtils#getRequestURL, both of which
Expand Down Expand Up @@ -273,6 +273,7 @@ private static final class GrpcReadListener implements ReadListener {
final AsyncContext asyncCtx;
final ServletInputStream input;
final InternalLogId logId;
private final AtomicBoolean closed = new AtomicBoolean(false);

GrpcReadListener(
ServletServerStream stream,
Expand Down Expand Up @@ -315,6 +316,11 @@ public void onDataAvailable() throws IOException {
@Override
public void onAllDataRead() {
logger.log(FINE, "[{0}] onAllDataRead", logId);
if (!closed.compareAndSet(false, true)) {
// https://github.com/eclipse/jetty.project/issues/8405
logger.log(FINE, "[{0}] onAllDataRead already called, skipping this one", logId);
return;
}
stream.transportState().runOnTransportThread(
() -> stream.transportState().inboundDataReceived(ReadableBuffers.empty(), true));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package io.grpc.servlet.jakarta.web;

import jakarta.servlet.AsyncContext;
import jakarta.servlet.AsyncListener;
import jakarta.servlet.ServletContext;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;

/**
* Util class to allow the complete() call to get some work done (writing trailers as a payload) before calling the
* actual container implementation. The container will finish closing the stream before invoking the async listener and
* formally informing the filter that the stream has closed, making this our last chance to intercept the closing of the
* stream before it happens.
*/
public class DelegatingAsyncContext implements AsyncContext {
private final AsyncContext delegate;

public DelegatingAsyncContext(AsyncContext delegate) {
this.delegate = delegate;
}

@Override
public ServletRequest getRequest() {
return delegate.getRequest();
}

@Override
public ServletResponse getResponse() {
return delegate.getResponse();
}

@Override
public boolean hasOriginalRequestAndResponse() {
return delegate.hasOriginalRequestAndResponse();
}

@Override
public void dispatch() {
delegate.dispatch();
}

@Override
public void dispatch(String path) {
delegate.dispatch(path);
}

@Override
public void dispatch(ServletContext context, String path) {
delegate.dispatch(context, path);
}

@Override
public void complete() {
delegate.complete();
}

@Override
public void start(Runnable run) {
delegate.start(run);
}

@Override
public void addListener(AsyncListener listener) {
delegate.addListener(listener);
}

@Override
public void addListener(AsyncListener listener, ServletRequest servletRequest,
ServletResponse servletResponse) {
delegate.addListener(listener, servletRequest, servletResponse);
}

@Override
public <T extends AsyncListener> T createListener(Class<T> clazz) throws ServletException {
return delegate.createListener(clazz);
}

@Override
public void setTimeout(long timeout) {
delegate.setTimeout(timeout);
}

@Override
public long getTimeout() {
return delegate.getTimeout();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package io.grpc.servlet.jakarta.web;

import io.grpc.internal.GrpcUtil;
import jakarta.servlet.AsyncContext;
import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.ServletRequest;
import jakarta.servlet.ServletResponse;
import jakarta.servlet.http.HttpFilter;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletRequestWrapper;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpServletResponseWrapper;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.function.Supplier;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.regex.Pattern;

/**
* Servlet filter that translates grpc-web on the fly to match what is expected by GrpcServlet. This work is done
* in-process with no addition copies to the request or response data - only the content type header and the trailer
* content is specially treated at this time.
*
* Note that grpc-web-text is not yet supported.
*/
public class GrpcWebFilter extends HttpFilter {
private static final Logger logger = Logger.getLogger(GrpcWebFilter.class.getName());

public static final String CONTENT_TYPE_GRPC_WEB = GrpcUtil.CONTENT_TYPE_GRPC + "-web";

@Override
public void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
throws IOException, ServletException {
if (isGrpcWeb(request)) {
// wrap the request and response to paper over the grpc-web details
GrpcWebHttpResponse wrappedResponse = new GrpcWebHttpResponse(response);
HttpServletRequestWrapper wrappedRequest = new HttpServletRequestWrapper(request) {
@Override
public String getContentType() {
// Adapt the content-type to replace grpc-web with grpc
return super.getContentType().replaceFirst(Pattern.quote(CONTENT_TYPE_GRPC_WEB),
GrpcUtil.CONTENT_TYPE_GRPC);
}

@Override
public AsyncContext startAsync() throws IllegalStateException {
return startAsync(this, wrappedResponse);
}

@Override
public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse)
throws IllegalStateException {
AsyncContext delegate = super.startAsync(servletRequest, servletResponse);
return new DelegatingAsyncContext(delegate) {
@Override
public void complete() {
// Write any trailers out to the output stream as a payload, since grpc-web doesn't
// use proper trailers.
try {
if (wrappedResponse.trailers != null) {
Map<String, String> map = wrappedResponse.trailers.get();
if (map != null) {
// write a payload, even for an empty set of trailers, but not for
// the absence of trailers.
int trailerLength = map.entrySet().stream()
.mapToInt(e -> e.getKey().length() + e.getValue().length() + 4).sum();
ByteBuffer payload = ByteBuffer.allocate(5 + trailerLength);
payload.put((byte) 0x80);
payload.putInt(trailerLength);
for (Map.Entry<String, String> entry : map.entrySet()) {
payload.put(entry.getKey().getBytes(StandardCharsets.US_ASCII));
payload.put(": ".getBytes(StandardCharsets.US_ASCII));
payload.put(entry.getValue().getBytes(StandardCharsets.US_ASCII));
payload.put("\r\n".getBytes(StandardCharsets.US_ASCII));
}
wrappedResponse.getOutputStream().write(payload.array());
}
}
} catch (IOException e) {
// complete() should not throw, but instead just log the error. In this case,
// the connection has likely been lost, so there is no way to send the trailers,
// so we just let the exception slide.
logger.log(Level.FINE, "Error sending grpc-web trailers", e);
}

// Let the superclass complete the stream so we formally close it
super.complete();
}
};
}
};

chain.doFilter(wrappedRequest, wrappedResponse);
} else {
chain.doFilter(request, response);
}
}

private static boolean isGrpcWeb(ServletRequest request) {
return request.getContentType() != null && request.getContentType().startsWith(CONTENT_TYPE_GRPC_WEB);
}

// Technically we should throw away content-length too, but the impl won't care
public static class GrpcWebHttpResponse extends HttpServletResponseWrapper {
private Supplier<Map<String, String>> trailers;

public GrpcWebHttpResponse(HttpServletResponse response) {
super(response);
}

@Override
public void setContentType(String type) {
// Adapt the content-type to be grpc-web
super.setContentType(
type.replaceFirst(Pattern.quote(GrpcUtil.CONTENT_TYPE_GRPC), CONTENT_TYPE_GRPC_WEB));
}

// intercept trailers and write them out as a message just before we complete
@Override
public void setTrailerFields(Supplier<Map<String, String>> supplier) {
trailers = supplier;
}

@Override
public Supplier<Map<String, String>> getTrailerFields() {
return trailers;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import io.deephaven.ssl.config.TrustJdk;
import io.deephaven.ssl.config.impl.KickstartUtils;
import io.grpc.servlet.web.websocket.WebSocketServerStream;
import io.grpc.servlet.jakarta.web.GrpcWebFilter;
import jakarta.servlet.DispatcherType;
import jakarta.websocket.server.ServerEndpointConfig;
import nl.altindag.ssl.SSLFactory;
Expand Down Expand Up @@ -78,6 +79,9 @@ public JettyBackedGrpcServer(
// Direct jetty all use this configuration as the root application
context.setContextPath("/");

// Handle grpc-web connections, translate to vanilla grpc
context.addFilter(new FilterHolder(new GrpcWebFilter()), "/*", EnumSet.noneOf(DispatcherType.class));

// Wire up the provided grpc filter
context.addFilter(new FilterHolder(filter), "/*", EnumSet.noneOf(DispatcherType.class));

Expand Down