Skip to content

Commit

Permalink
http (feature): Request.withDest(ServerAddrress) to switch the destin…
Browse files Browse the repository at this point in the history
…ation host in http client (#3479)

- **http (feature): Enable setting a destionation address for each HTTP
request**
- **Add dest to client logs**

Closes #3469
Closes #3471
  • Loading branch information
xerial authored Apr 8, 2024
1 parent 1f70229 commit edc39e7
Show file tree
Hide file tree
Showing 14 changed files with 98 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class RPCRequestBenchmark extends LogSupport {
new HttpChannel {
private val responseCodec = MessageCodec.of[GreeterResponse]

override val destination: ServerAddress = ServerAddress("localhost:8080")

override def send(req: HttpMessage.Request, channelConfig: HttpChannelConfig): HttpMessage.Response = {
val ret = emptyServer.hello(req.message.toContentString)
Http.response(HttpStatus.Ok_200).withJson(responseCodec.toJson(ret))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import wvlet.log.LogSupport
import java.io.IOException
import java.util.concurrent.TimeUnit

class OkHttpChannel(serverAddress: ServerAddress, config: HttpClientConfig) extends HttpChannel with LogSupport {
class OkHttpChannel(val destination: ServerAddress, config: HttpClientConfig) extends HttpChannel with LogSupport {
private[this] val client = {
var builder = new okhttp3.OkHttpClient.Builder()
.readTimeout(config.readTimeout.toMillis, TimeUnit.MILLISECONDS)
Expand Down Expand Up @@ -56,15 +56,21 @@ class OkHttpChannel(serverAddress: ServerAddress, config: HttpClientConfig) exte
newClient
}

override def send(req: HttpMessage.Request, channelConfig: HttpChannelConfig): HttpMessage.Response = {
override def send(
req: HttpMessage.Request,
channelConfig: HttpChannelConfig
): HttpMessage.Response = {
val request: okhttp3.Request = convertRequest(req)

val newClient = prepareClient(channelConfig)
val response = newClient.newCall(request).execute()
response.toHttpResponse
}

override def sendAsync(req: HttpMessage.Request, channelConfig: HttpChannelConfig): Rx[HttpMessage.Response] = {
override def sendAsync(
req: HttpMessage.Request,
channelConfig: HttpChannelConfig
): Rx[HttpMessage.Response] = {
val request: okhttp3.Request = convertRequest(req)
val newClient = prepareClient(channelConfig)
val v = Rx.variable[Option[HttpMessage.Response]](None)
Expand Down Expand Up @@ -93,7 +99,8 @@ class OkHttpChannel(serverAddress: ServerAddress, config: HttpClientConfig) exte
}

val url = HttpUrl
.get(serverAddress.uri).newBuilder()
.get(request.dest.getOrElse(destination).uri)
.newBuilder()
.encodedPath(request.path)
.encodedQuery(queryParams)
.build()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,27 @@ import scala.util.{Failure, Success, Try}
* @param serverAddress
* @param config
*/
class JSFetchChannel(serverAddress: ServerAddress, config: HttpClientConfig) extends HttpChannel with LogSupport {
class JSFetchChannel(val destination: ServerAddress, config: HttpClientConfig) extends HttpChannel with LogSupport {
private[client] implicit val executionContext: ExecutionContext = Compat.defaultExecutionContext

override def close(): Unit = {
// nothing to do
}

override def send(req: HttpMessage.Request, channelConfig: HttpChannelConfig): HttpMessage.Response = {
override def send(
req: HttpMessage.Request,
channelConfig: HttpChannelConfig
): HttpMessage.Response = {
// Blocking call cannot be supported in JS
???
}

override def sendAsync(request: HttpMessage.Request, channelConfig: HttpChannelConfig): Rx[HttpMessage.Response] = {
override def sendAsync(
request: HttpMessage.Request,
channelConfig: HttpChannelConfig
): Rx[HttpMessage.Response] = {
val path = if (request.uri.startsWith("/")) request.uri else s"/${request.uri}"
val uri = s"${serverAddress.uri}${path}"
val uri = s"${request.dest.getOrElse(destination).uri}${path}"

val req = new org.scalajs.dom.RequestInit {
method = request.method match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import scala.concurrent.{ExecutionContext, Promise, TimeoutException}
import scala.scalajs.js.typedarray.{ArrayBuffer, TypedArrayBuffer}
import scala.util.Try

class JSHttpClientChannel(serverAddress: ServerAddress, private[client] val config: HttpClientConfig)
class JSHttpClientChannel(val destination: ServerAddress, private[client] val config: HttpClientConfig)
extends HttpChannel
with LogSupport {

Expand All @@ -35,7 +35,10 @@ class JSHttpClientChannel(serverAddress: ServerAddress, private[client] val conf
// nothing to do
}

override def send(request: HttpMessage.Request, channelConfig: HttpChannelConfig): HttpMessage.Response = ???
override def send(
request: HttpMessage.Request,
channelConfig: HttpChannelConfig
): HttpMessage.Response = ???

override def sendAsync(
request: HttpMessage.Request,
Expand All @@ -45,7 +48,7 @@ class JSHttpClientChannel(serverAddress: ServerAddress, private[client] val conf
val xhr = new dom.XMLHttpRequest()

val path = if (request.uri.startsWith("/")) request.uri else s"/${request.uri}"
val uri = s"${serverAddress.uri}${path}"
val uri = s"${request.dest.getOrElse(destination).uri}${path}"

trace(s"Sending request: ${request}: ${uri}")
xhr.open(request.method, uri)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ import scala.jdk.CollectionConverters.*

/**
* Http connection implementation using Http Client of Java 11
* @param serverAddress
* @param destination
* @param config
*/
class JavaHttpClientChannel(serverAddress: ServerAddress, private[http] val config: HttpClientConfig)
class JavaHttpClientChannel(val destination: ServerAddress, private[http] val config: HttpClientConfig)
extends HttpChannel
with LogSupport {
private val javaHttpClient: HttpClient = initClient(config)
Expand Down Expand Up @@ -63,7 +63,7 @@ class JavaHttpClientChannel(serverAddress: ServerAddress, private[http] val conf

override def send(req: Request, channelConfig: HttpChannelConfig): Response = {
// New Java's HttpRequest is immutable, so we can reuse the same request instance
val httpRequest = buildRequest(serverAddress, req, channelConfig)
val httpRequest = buildRequest(req, channelConfig)
val httpResponse: HttpResponse[InputStream] =
javaHttpClient.send(httpRequest, BodyHandlers.ofInputStream())

Expand All @@ -73,7 +73,7 @@ class JavaHttpClientChannel(serverAddress: ServerAddress, private[http] val conf
override def sendAsync(req: Request, channelConfig: HttpChannelConfig): Rx[Response] = {
val v = Rx.variable[Option[Response]](None)
try {
val httpRequest = buildRequest(serverAddress, req, channelConfig)
val httpRequest = buildRequest(req, channelConfig)
javaHttpClient
.sendAsync(httpRequest, BodyHandlers.ofInputStream())
.thenAccept(new Consumer[HttpResponse[InputStream]] {
Expand All @@ -96,11 +96,10 @@ class JavaHttpClientChannel(serverAddress: ServerAddress, private[http] val conf
}

private def buildRequest(
serverAddress: ServerAddress,
request: Request,
channelConfig: HttpChannelConfig
): HttpRequest = {
val uri = s"${serverAddress.uri}${if (request.uri.startsWith("/")) request.uri
val uri = s"${request.dest.getOrElse(destination).uri}${if (request.uri.startsWith("/")) request.uri
else s"/${request.uri}"}"

val requestBuilder = HttpRequest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ import scala.jdk.CollectionConverters.*
* @param serverAddress
* @param config
*/
class URLConnectionChannel(serverAddress: ServerAddress, config: HttpClientConfig) extends HttpChannel {
class URLConnectionChannel(val destination: ServerAddress, config: HttpClientConfig) extends HttpChannel {
override def send(request: Request, channelConfig: HttpChannelConfig): Response = {
val url = s"${serverAddress.uri}${if (request.uri.startsWith("/")) request.uri
val url = s"${request.dest.getOrElse(destination).uri}${if (request.uri.startsWith("/")) request.uri
else s"/${request.uri}"}"

val conn0: HttpURLConnection =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,21 @@ import wvlet.log.LogSupport
class HttpClientLoggingFilterTest extends AirSpec {

class DummyHttpChannel extends HttpChannel with LogSupport {
override def send(req: HttpMessage.Request, channelConfig: HttpChannelConfig): HttpMessage.Response = {
override def destination: ServerAddress = ServerAddress("localhost:8080")
override def send(
req: HttpMessage.Request,
channelConfig: HttpChannelConfig
): HttpMessage.Response = {
Http
.response(HttpStatus.Ok_200).withJson("""{"message":"hello"}""")
.withHeader(HttpHeader.xAirframeRPCStatus, RPCStatus.SUCCESS_S0.code.toString)
}

override def sendAsync(req: HttpMessage.Request, channelConfig: HttpChannelConfig): Rx[HttpMessage.Response] = ???
override def close(): Unit = {}
override def sendAsync(
req: HttpMessage.Request,
channelConfig: HttpChannelConfig
): Rx[HttpMessage.Response] = ???
override def close(): Unit = {}
}

protected override def design: Design = {
Expand All @@ -55,6 +62,10 @@ class HttpClientLoggingFilterTest extends AirSpec {
val m = RPCMethod("/rpc_method", "demo.RPCClass", "hello", Surface.of[Map[String, Any]], Surface.of[String])
client.rpc[Map[String, Any], String](m, Map("message" -> "world"))
}

test("switch dest") {
client.send(Http.GET("/").withDest(ServerAddress("localhost:8081")))
}
}

}
16 changes: 13 additions & 3 deletions airframe-http/src/main/scala/wvlet/airframe/http/HttpMessage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,9 @@ object HttpMessage {
uri: String = "/",
header: HttpMultiMap = HttpMultiMap.empty,
message: Message = EmptyMessage,
// [optional] Destination address for sending the request. HttpChannel implementation should use this address
dest: Option[ServerAddress] = None,
// Remote address of the HTTP server, which is used for server-side logging purpose
remoteAddress: Option[ServerAddress] = None
) extends HttpMessage[Request] {
override def toString: String = s"Request(${method},${uri},${header})"
Expand All @@ -230,9 +233,16 @@ object HttpMessage {
*/
def query: HttpMultiMap = extractQueryFromUri(uri)

def withFilter(f: Request => Request): Request = f(this)
def withMethod(method: String): Request = this.copy(method = method)
def withUri(uri: String): Request = this.copy(uri = uri)
def withFilter(f: Request => Request): Request = f(this)
def withMethod(method: String): Request = this.copy(method = method)
def withUri(uri: String): Request = this.copy(uri = uri)

/**
* Overwrite the default destination address of the request
* @param dest
* @return
*/
def withDest(dest: ServerAddress): Request = this.copy(dest = Some(dest))
def withRemoteAddress(remoteAddress: ServerAddress): Request = this.copy(remoteAddress = Some(remoteAddress))

override protected def copyWith(newHeader: HttpMultiMap): Request = this.copy(header = newHeader)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ trait AsyncClient extends AsyncClientCompat with HttpClientFactory[AsyncClient]
protected def channel: HttpChannel
def config: HttpClientConfig

private val httpLogger: HttpLogger = config.newHttpLogger
private val httpLogger: HttpLogger = config.newHttpLogger(channel.destination)
private val loggingFilter: HttpClientFilter = config.newLoggingFilter(httpLogger)
private val circuitBreaker: CircuitBreaker = config.circuitBreaker

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package wvlet.airframe.http.client

import wvlet.airframe.http.HttpMessage.{Request, Response}
import wvlet.airframe.http.ServerAddress
import wvlet.airframe.rx.Rx

import scala.concurrent.duration.Duration
Expand All @@ -33,11 +34,25 @@ trait HttpChannelConfig {
trait HttpChannel extends AutoCloseable {

/**
* Send the request without modification.
* The default destination address to send requests
* @return
*/
def destination: ServerAddress

/**
* Send the request as is to the destination
* @param req
* @param channelConfig
* @return
*/
def send(req: Request, channelConfig: HttpChannelConfig): Response

/**
* Send an async request as is to the destination. Until the returned Rx is evaluated (e.g., by calling Rx.run), the
* request is not sent.
* @param req
* @param channelConfig
* @return
*/
def sendAsync(req: Request, channelConfig: HttpChannelConfig): Rx[Response]
}
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,15 @@ case class HttpClientConfig(
this.copy(httpLoggerProvider = loggerProvider)
}

def newHttpLogger: HttpLogger = {
httpLoggerProvider(httpLoggerConfig.addExtraTags(ListMap("client_name" -> name)))
def newHttpLogger(dest: ServerAddress): HttpLogger = {
httpLoggerProvider(
httpLoggerConfig.addExtraTags(
ListMap(
"client_name" -> name,
"dest" -> dest.hostAndPort
)
)
)
}

def newLoggingFilter(logger: HttpLogger): HttpClientFilter = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ trait SyncClient extends SyncClientCompat with HttpClientFactory[SyncClient] wit
protected def channel: HttpChannel
def config: HttpClientConfig

private val clientLogger: HttpLogger = config.newHttpLogger
private val clientLogger: HttpLogger = config.newHttpLogger(channel.destination)
private val loggingFilter: HttpClientFilter = config.newLoggingFilter(clientLogger)
private val circuitBreaker: CircuitBreaker = config.circuitBreaker

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ object HttpLogs extends LogSupport {
if (queryString.nonEmpty) {
m += "query_string" -> queryString
}
request.dest.foreach { d =>
m += "dest" -> d.hostAndPort
}
request.remoteAddress.foreach { remoteAddr =>
m += "remote_address" -> remoteAddr.hostAndPort
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,18 @@ object HttpClientFilterTest extends AirSpec {
private class DummyHttpChannel(reply: PartialFunction[Request, Response]) extends HttpChannel {
private val requestCount = new AtomicInteger(0)

override def destination: ServerAddress = ServerAddress("localhost:8080")
override def send(req: HttpMessage.Request, channelConfig: HttpChannelConfig): Response = {
if (reply.isDefinedAt(req)) {
reply(req)
} else {
throw RPCStatus.NOT_FOUND_U5.newException(s"RPC method not found: ${req.path}")
}
}
override def sendAsync(req: HttpMessage.Request, channelConfig: HttpChannelConfig): Rx[Response] = {
override def sendAsync(
req: HttpMessage.Request,
channelConfig: HttpChannelConfig
): Rx[Response] = {
Rx.single(send(req, channelConfig))
}
override def close(): Unit = {}
Expand Down

0 comments on commit edc39e7

Please sign in to comment.