Skip to content

Commit

Permalink
feat: add http/2 support (#568)
Browse files Browse the repository at this point in the history
* feat: add http/2 support

Signed-off-by: Keming <kemingy94@gmail.com>

* update readme

Signed-off-by: Keming <kemingy94@gmail.com>

* Update README.md

Signed-off-by: zclzc <38581401+lkevinzc@users.noreply.github.com>

* fix sd1.5 link, add http/2 test

Signed-off-by: Keming <kemingy94@gmail.com>

---------

Signed-off-by: Keming <kemingy94@gmail.com>
Signed-off-by: zclzc <38581401+lkevinzc@users.noreply.github.com>
Co-authored-by: zclzc <38581401+lkevinzc@users.noreply.github.com>
  • Loading branch information
kemingy and lkevinzc authored Sep 22, 2024
1 parent 4ce3e1a commit f0ea3b5
Show file tree
Hide file tree
Showing 9 changed files with 143 additions and 90 deletions.
180 changes: 107 additions & 73 deletions Cargo.lock

Large diffs are not rendered by default.

6 changes: 2 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "mosec"
version = "0.8.7"
version = "0.8.8"
authors = ["Keming <kemingy94@gmail.com>", "Zichen <lkevinzc@gmail.com>"]
edition = "2021"
license = "Apache-2.0"
Expand All @@ -10,9 +10,7 @@ description = "Model Serving made Efficient in the Cloud."
documentation = "https://docs.rs/mosec"
exclude = ["target", "examples", "tests", "scripts"]

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
hyper = { version = "1", features = ["http1", "server"] }
bytes = "1"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["local-time", "json"] }
Expand All @@ -21,7 +19,7 @@ derive_more = { version = "1", features = ["display", "error", "from"] }
# MPMS that only one consumer sees each message & async
async-channel = "2.2"
prometheus-client = "0.22"
axum = "0.7"
axum = { version = "0.7", default-features = false, features = ["matched-path", "original-uri", "query", "tokio", "http1", "http2"]}
async-stream = "0.3.5"
serde = "1.0"
serde_json = "1.0"
Expand Down
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ from mosec.mixin import MsgpackMixin
logger = get_logger()
```

Then, we **build an API** for clients to query a text prompt and obtain an image based on the [stable-diffusion-v1-5 model](https://huggingface.co/runwayml/stable-diffusion-v1-5) in just 3 steps.
Then, we **build an API** for clients to query a text prompt and obtain an image based on the [stable-diffusion-v1-5 model](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) in just 3 steps.

1) Define your service as a class which inherits `mosec.Worker`. Here we also inherit `MsgpackMixin` to employ the [msgpack](https://msgpack.org/index.html) serialization format<sup>(a)</sup></a>.

Expand All @@ -104,10 +104,9 @@ Then, we **build an API** for clients to query a text prompt and obtain an image
class StableDiffusion(MsgpackMixin, Worker):
def __init__(self):
self.pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
"sd-legacy/stable-diffusion-v1-5", torch_dtype=torch.float16
)
device = "cuda" if torch.cuda.is_available() else "cpu"
self.pipe = self.pipe.to(device)
self.pipe.enable_model_cpu_offload()
self.example = ["useless example prompt"] * 4 # warmup (batch_size=4)

def forward(self, data: List[str]) -> List[memoryview]:
Expand Down Expand Up @@ -229,6 +228,7 @@ More ready-to-use examples can be found in the [Example](https://mosecorg.github
- For multi-stage services, note that the data passing through different stages will be serialized/deserialized by the `serialize_ipc/deserialize_ipc` methods, so extremely large data might make the whole pipeline slow. The serialized data is passed to the next stage through rust by default, you could enable shared memory to potentially reduce the latency (ref [RedisShmIPCMixin](https://mosecorg.github.io/mosec/examples/ipc.html#redis-shm-ipc-py)).
- You should choose appropriate `serialize/deserialize` methods, which are used to decode the user request and encode the response. By default, both are using JSON. However, images and embeddings are not well supported by JSON. You can choose msgpack which is faster and binary compatible (ref [Stable Diffusion](https://mosecorg.github.io/mosec/examples/stable_diffusion.html)).
- Configure the threads for OpenBLAS or MKL. It might not be able to choose the most suitable CPUs used by the current Python process. You can configure it for each worker by using the [env](https://mosecorg.github.io/mosec/reference/interface.html#mosec.server.Server.append_worker) (ref [custom GPU allocation](https://mosecorg.github.io/mosec/examples/env.html)).
- Enable HTTP/2 from client side. `mosec` automatically adapts to user's protocol (e.g., HTTP/2) since v0.8.8.

## Adopters

Expand Down
6 changes: 3 additions & 3 deletions examples/stable_diffusion/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@
class StableDiffusion(MsgpackMixin, Worker):
def __init__(self):
self.pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
"sd-legacy/stable-diffusion-v1-5",
torch_dtype=torch.float16,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
self.pipe = self.pipe.to(device) # type: ignore
self.pipe.enable_model_cpu_offload()
self.example = ["useless example prompt"] * 4 # warmup (bs=4)

def forward(self, data: List[str]) -> List[memoryview]:
Expand Down
2 changes: 1 addition & 1 deletion requirements/dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@ mypy~=1.11
pyright~=1.1
ruff~=0.6
pre-commit>=2.15.0
httpx==0.27.2
httpx[http2]==0.27.2
httpx-sse==0.4.0
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ fn main() {
.with(
output
.with_filter(filter::filter_fn(|metadata| {
!metadata.target().starts_with("hyper")
!metadata.target().starts_with("h2")
}))
.with_filter(filter::LevelFilter::DEBUG),
)
Expand Down
5 changes: 2 additions & 3 deletions src/routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@
use std::time::Duration;

use axum::body::{to_bytes, Body};
use axum::http::Uri;
use axum::http::header::{HeaderValue, CONTENT_TYPE};
use axum::http::{Request, Response, StatusCode, Uri};
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::response::IntoResponse;
use bytes::Bytes;
use hyper::header::{HeaderValue, CONTENT_TYPE};
use hyper::{Request, Response, StatusCode};
use prometheus_client::encoding::text::encode;
use tracing::warn;
use utoipa::OpenApi;
Expand Down
2 changes: 1 addition & 1 deletion src/tasks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use std::time::{Duration, Instant};

use axum::http::StatusCode;
use bytes::Bytes;
use hyper::StatusCode;
use tokio::sync::{mpsc, oneshot, Barrier};
use tokio::time;
use tracing::{debug, error, info, warn};
Expand Down
22 changes: 22 additions & 0 deletions tests/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,15 @@ def http_client():
yield client


@pytest.fixture
def http2_client():
# force to use HTTP/2
with httpx.Client(
base_url=f"http://127.0.0.1:{TEST_PORT}", http1=False, http2=True
) as client:
yield client


@pytest.fixture(scope="session")
def mosec_service(request):
params = request.param.split(" ")
Expand All @@ -54,6 +63,19 @@ def mosec_service(request):
assert wait_for_port_free(port=TEST_PORT), "service failed to stop"


@pytest.mark.parametrize(
"mosec_service, http2_client",
[
pytest.param("square_service", "", id="HTTP/2"),
],
indirect=["mosec_service", "http2_client"],
)
def test_http2_service(mosec_service, http2_client):
resp = http2_client.get("/")
assert resp.status_code == HTTPStatus.OK
assert resp.http_version == "HTTP/2"


@pytest.mark.parametrize(
"mosec_service, http_client",
[
Expand Down

0 comments on commit f0ea3b5

Please sign in to comment.