[hackathon] truss metrics #1197

Draft · wants to merge 4 commits into main
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.45rc009"
version = "0.9.45rc013"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
5 changes: 5 additions & 0 deletions truss/templates/control/control/application.py
@@ -13,6 +13,7 @@
from helpers.inference_server_process_controller import InferenceServerProcessController
from helpers.inference_server_starter import async_inference_server_startup_flow
from helpers.truss_patch.model_container_patch_applier import ModelContainerPatchApplier
from prometheus_client import make_asgi_app
from shared.logging import setup_logging
from starlette.datastructures import State

@@ -103,6 +104,10 @@ async def start_background_inference_startup():
app.state = app_state
app.include_router(control_app)

# Mount the prometheus_client ASGI app so the control server serves /metrics
metrics_app = make_asgi_app()
app.mount("/metrics", metrics_app)

@app.on_event("shutdown")
def on_shutdown():
# FastApi handles the term signal to start the shutdown flow. Here we
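For context, prometheus_client.make_asgi_app() wraps the default metric registry in a small ASGI app that serves the Prometheus text exposition format, so this mount turns the control server into a scrape target at /metrics. A minimal sketch of checking the endpoint against a locally running container (the port is illustrative, not something this diff pins down):

import requests

# Hypothetical local port; use whatever the control server is actually bound to.
resp = requests.get("http://localhost:8091/metrics")
assert resp.status_code == 200
# Default collectors are always present, e.g. python_gc_objects_collected_total.
print(resp.text[:200])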
1 change: 1 addition & 0 deletions truss/templates/control/requirements.txt
@@ -7,3 +7,4 @@ tenacity==8.1.0
httpx==0.27.0
python-json-logger==2.0.2
loguru==0.7.2
prometheus_client==0.15.0
1 change: 1 addition & 0 deletions truss/templates/server/requirements.txt
@@ -18,3 +18,4 @@ requests==2.31.0
uvicorn==0.24.0
uvloop==0.19.0
aiofiles==24.1.0
prometheus-client==0.15.0
6 changes: 6 additions & 0 deletions truss/templates/server/truss_server.py
@@ -24,6 +24,7 @@
from opentelemetry import propagate as otel_propagate
from opentelemetry import trace
from opentelemetry.sdk import trace as sdk_trace
from prometheus_client import make_asgi_app
from shared import serialization, util
from shared.logging import setup_logging
from shared.secrets_resolver import SecretsResolver
@@ -342,6 +343,11 @@ def exit_self():
on_term=exit_self,
)
app.add_middleware(BaseHTTPMiddleware, dispatch=termination_handler_middleware)

# Mount the prometheus_client ASGI app so the inference server serves /metrics
metrics_app = make_asgi_app()
app.mount("/metrics", metrics_app)

return app

def start(self):
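Because the inference server now mounts the same make_asgi_app() endpoint, any metric a user model registers with prometheus_client's default registry shows up at the pod's /metrics path. A rough sketch of a model doing that, mirroring the integration test added below (the metric name is illustrative):

from prometheus_client import Histogram

class Model:
    def __init__(self):
        # Registered on the default registry, which make_asgi_app() exposes.
        self._latency = Histogram(
            "predict_latency_seconds", "Time spent in predict()"
        )

    def predict(self, model_input):
        # Histogram.time() acts as a context manager and records the duration.
        with self._latency.time():
            return model_input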
61 changes: 61 additions & 0 deletions truss/tests/test_config.py
@@ -15,6 +15,7 @@
Accelerator,
AcceleratorSpec,
BaseImage,
CustomMetricConfig,
DockerAuthSettings,
DockerAuthType,
ModelCache,
@@ -126,6 +127,66 @@ def test_parse_resources(input_dict, expect_resources, output_dict):
assert parsed_result.to_dict() == output_dict


@pytest.mark.parametrize(
"input_dict, expect_metrics, output_dict",
[
(
{
"name": "metric_name",
"display_name": "Metric Name",
"type": "histogram",
"unit": "ms",
},
CustomMetricConfig(
name="metric_name",
display_name="Metric Name",
type="histogram",
unit="ms",
),
{
"name": "metric_name",
"display_name": "Metric Name",
"type": "histogram",
"unit": "ms",
},
),
],
)
def test_parse_custom_metric(input_dict, expect_metrics, output_dict):
parsed_result = CustomMetricConfig.from_dict(input_dict)
assert parsed_result == expect_metrics
assert parsed_result.to_dict() == output_dict


def test_config_metrics(default_config):
default_config["metrics"] = [
{
"name": "metric_name",
"display_name": "Metric Name",
"type": "histogram",
"unit": "ms",
},
{
"name": "metric_name2",
"display_name": "Metric Name 2",
"type": "counter",
"unit": "count",
},
]
config = TrussConfig.from_dict(default_config)
assert config.metrics == [
CustomMetricConfig(
name="metric_name", display_name="Metric Name", type="histogram", unit="ms"
),
CustomMetricConfig(
name="metric_name2",
display_name="Metric Name 2",
type="counter",
unit="count",
),
]


@pytest.mark.parametrize(
"input_str, expected_acc",
[
25 changes: 25 additions & 0 deletions truss/tests/test_model_inference.py
@@ -760,6 +760,31 @@ def _test_invocations(expected_code):
_test_invocations(200)


@pytest.mark.integration
def test_metrics():
model = """
from fastapi.responses import Response
from prometheus_client import Counter

class Model:
def __init__(self):
self.counter = Counter('my_really_cool_metric', 'my really cool metric description')

def predict(self, model_input):
self.counter.inc(10)
return model_input
"""
config = "model_name: metrics-truss"
with ensure_kill_all(), temp_truss(model, config) as tr:
_ = tr.docker_run(local_port=8090, detach=True, wait_for_server_ready=True)
metrics_url = "http://localhost:8090/metrics"
requests.post(PREDICT_URL, json={})
resp = requests.get(metrics_url)
assert resp.status_code == 200
assert "my_really_cool_metric_total 10.0" in resp.text
assert "my_really_cool_metric_created" in resp.text


@pytest.mark.integration
def test_setup_environment():
# Test truss that uses setup_environment() without load()
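The test_metrics assertions above lean on prometheus_client's exposition conventions: a Counter named my_really_cool_metric is exported as a my_really_cool_metric_total sample plus a my_really_cool_metric_created registration timestamp, so after one predict call that increments by 10 the scrape output contains lines roughly like the following (the _created value shown is a placeholder):

my_really_cool_metric_total 10.0
my_really_cool_metric_created 1.7e+09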
35 changes: 34 additions & 1 deletion truss/truss_config.py
@@ -243,6 +243,31 @@ def to_dict(self):
}


@dataclass
class CustomMetricConfig:
name: str
display_name: str
type: str
unit: str

@staticmethod
def from_dict(d):
return CustomMetricConfig(
name=d.get("name"),
display_name=d.get("display_name"),
type=d.get("type"),
unit=d.get("unit"),
)

def to_dict(self) -> dict:
return {
"name": self.name,
"display_name": self.display_name,
"type": self.type,
"unit": self.unit,
}


@dataclass
class ExternalDataItem:
"""A piece of remote data, to be made available to the Truss at serving time.
@@ -546,6 +571,7 @@ class TrussConfig:
model_cache: ModelCache = field(default_factory=ModelCache)
trt_llm: Optional[TRTLLMConfiguration] = None
build_commands: List[str] = field(default_factory=list)
metrics: List[CustomMetricConfig] = field(default_factory=list)

@property
def canonical_python_version(self) -> str:
@@ -605,6 +631,9 @@ def from_dict(d):
d.get("trt_llm"), lambda x: TRTLLMConfiguration(**x)
),
build_commands=d.get("build_commands", []),
metrics=transform_optional(
d.get("metrics") or [], lambda x: [CustomMetricConfig(**m) for m in x]
),
)
config.validate()
return config
@@ -780,7 +809,11 @@ def obj_to_dict(obj, verbose: bool = False):
d["docker_auth"] = transform_optional(
field_curr_value, lambda data: data.to_dict()
)
elif field_name == "metrics":
d["metrics"] = transform_optional(
field_curr_value,
lambda data: [metric.to_dict() for metric in data] if data else [],
)
else:
d[field_name] = field_curr_value

return d
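Taken together, the new field lets a metrics list in the config dict (and presumably config.yaml) round-trip through CustomMetricConfig when the config is parsed and serialized. A small sketch of that round trip, using the same shape as the unit tests (the metric name here is just an example):

from truss.truss_config import CustomMetricConfig

raw = {
    "name": "predict_latency",
    "display_name": "Predict Latency",
    "type": "histogram",
    "unit": "ms",
}
metric = CustomMetricConfig.from_dict(raw)
assert metric == CustomMetricConfig(**raw)
assert metric.to_dict() == raw  # the four fields serialize back unchanged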