diff --git a/pyproject.toml b/pyproject.toml
index bcbcb6333..41fc831b2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "truss"
-version = "0.9.53"
+version = "0.9.54rc5"
 description = "A seamless bridge from model development to model delivery"
 license = "MIT"
 readme = "README.md"
diff --git a/truss/base/constants.py b/truss/base/constants.py
index 2ffc69518..a9a4a5017 100644
--- a/truss/base/constants.py
+++ b/truss/base/constants.py
@@ -105,9 +105,11 @@
 REGISTRY_BUILD_SECRET_PREFIX = "DOCKER_REGISTRY_"
 
-TRTLLM_BASE_IMAGE = "baseten/briton-server:v0.13.0_v0.0.17"
+TRTLLM_SPEC_DEC_TARGET_MODEL_NAME = "target"
+TRTLLM_SPEC_DEC_DRAFT_MODEL_NAME = "draft"
+TRTLLM_BASE_IMAGE = "baseten/briton-server:v0.13.0-4fd8a10-5e5c3d7"
 TRTLLM_PYTHON_EXECUTABLE = "/usr/bin/python3"
-BASE_TRTLLM_REQUIREMENTS = ["briton==0.3.9"]
+BASE_TRTLLM_REQUIREMENTS = ["briton==0.3.10"]
 AUDIO_MODEL_TRTLLM_REQUIREMENTS = [
     "--extra-index-url https://pypi.nvidia.com",
     "tensorrt_cu12_bindings==10.2.0.post1",
diff --git a/truss/base/trt_llm_config.py b/truss/base/trt_llm_config.py
index 315165402..e156692eb 100644
--- a/truss/base/trt_llm_config.py
+++ b/truss/base/trt_llm_config.py
@@ -55,6 +55,10 @@ class TrussTRTLLMBatchSchedulerPolicy(str, Enum):
     GUARANTEED_NO_EVICT = "guaranteed_no_evict"
 
 
+class TrussSpecDecMode(str, Enum):
+    DRAFT_EXTERNAL = "DRAFT_TOKENS_EXTERNAL"
+
+
 class TrussTRTLLMBuildConfiguration(BaseModel):
     base_model: TrussTRTLLMModel
     max_seq_len: int
@@ -73,13 +77,9 @@ class TrussTRTLLMBuildConfiguration(BaseModel):
     plugin_configuration: TrussTRTLLMPluginConfiguration = (
         TrussTRTLLMPluginConfiguration()
     )
-    kv_cache_free_gpu_mem_fraction: float = 0.9
     num_builder_gpus: Optional[int] = None
-    enable_chunked_context: bool = False
-    batch_scheduler_policy: TrussTRTLLMBatchSchedulerPolicy = (
-        TrussTRTLLMBatchSchedulerPolicy.GUARANTEED_NO_EVICT
-    )
-    default_max_tokens: Optional[int] = None
+    speculative_decoding_mode: Optional[TrussSpecDecMode] = None
+    max_draft_len: Optional[int] = None
 
     @validator("max_beam_width")
     def check_max_beam_width(cls, v: int):
@@ -91,40 +91,26 @@ def check_max_beam_width(cls, v: int):
         return v
 
 
-class TrussTRTLLMServingConfiguration(BaseModel):
-    engine_repository: str
-    tokenizer_repository: str
-    tensor_parallel_count: int = 1
-    pipeline_parallel_count: int = 1
+class TrussTRTLLMRuntimeConfiguration(BaseModel):
+    kv_cache_free_gpu_mem_fraction: float = 0.9
+    enable_chunked_context: bool = False
+    num_draft_tokens: Optional[int] = None
+    batch_scheduler_policy: TrussTRTLLMBatchSchedulerPolicy = (
+        TrussTRTLLMBatchSchedulerPolicy.GUARANTEED_NO_EVICT
+    )
+    request_default_max_tokens: Optional[int] = None
 
 
 class TRTLLMConfiguration(BaseModel):
-    serve: Optional[TrussTRTLLMServingConfiguration] = None
-    build: Optional[TrussTRTLLMBuildConfiguration] = None
+    runtime: TrussTRTLLMRuntimeConfiguration = TrussTRTLLMRuntimeConfiguration()
+    build: TrussTRTLLMBuildConfiguration
 
     def __init__(self, **data):
         super().__init__(**data)
-        self._validate_minimum_required_configuration()
         self._validate_kv_cache_flags()
         if self.build.checkpoint_repository.source == CheckpointSource.HF:
             self._validate_hf_repo_id()
 
-    # In pydantic v2 this would be `@model_validator(mode="after")` and
-    # the __init__ override can be removed.
-    def _validate_minimum_required_configuration(self):
-        if not self.serve and not self.build:
-            raise ValueError("Either serve or build configurations must be provided")
-        if self.serve and self.build:
-            raise ValueError("Both serve and build configurations cannot be provided")
-        if self.serve is not None:
-            if (self.serve.engine_repository is None) ^ (
-                self.serve.tokenizer_repository is None
-            ):
-                raise ValueError(
-                    "Both engine_repository and tokenizer_repository must be provided"
-                )
-        return self
-
     def _validate_kv_cache_flags(self):
         if self.build is None:
             return self
@@ -160,3 +146,41 @@ def requires_build(self):
     # when pydantic v2 is used here
     def to_json_dict(self, verbose=True):
         return json.loads(self.json(exclude_unset=not verbose))
+
+
+class TRTLLMSpeculativeDecodingConfiguration(BaseModel):
+    target: TRTLLMConfiguration
+    draft: TRTLLMConfiguration
+    total_token_limit: int = 500000
+
+    def __init__(self, **data):
+        super().__init__(**data)
+        self._spec_dec_configs = [
+            self.target.build.speculative_decoding_mode,
+            self.target.build.max_draft_len,
+        ] + (
+            [self.draft.runtime.num_draft_tokens]
+            if self.draft.runtime and self.draft.runtime.num_draft_tokens
+            else [False]
+        )
+        self._validate_spec_dec()
+
+    def _validate_spec_dec(self):
+        if any(self._spec_dec_configs):
+            if not all(self._spec_dec_configs):
+                raise ValueError(
+                    "Speculative decoding requires all of `target.build.speculative_decoding_mode`, `target.build.max_draft_len`, and `draft.runtime.num_draft_tokens` to be configured."
+                )
+        for trt_llm_config in [self.target, self.draft]:
+            if trt_llm_config.build.base_model is TrussTRTLLMModel.WHISPER:
+                raise ValueError("Speculative decoding for Whisper is not supported.")
+        if (
+            self.target.build.tensor_parallel_count
+            != self.draft.build.tensor_parallel_count
+        ):
+            raise ValueError(
+                "Speculative decoding requires the same tensor parallelism for target and draft models."
+            )
+
+    def to_json_dict(self, verbose=True):
+        return json.loads(self.json(exclude_unset=not verbose))
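
Note (editorial): to make the new schema concrete, the sketch below constructs the spec-dec variant directly. It mirrors the `trtllm_spec_dec_config` fixture added to `truss/tests/test_config.py` later in this diff; the repo id and numbers are the fixture's placeholders, not tuning advice.

    from truss.base.trt_llm_config import (
        TRTLLMSpeculativeDecodingConfiguration,
        TrussSpecDecMode,
    )

    # Placeholder checkpoint reused for target and draft; a real deployment
    # would pair a large target model with a small draft model.
    checkpoint = {"source": "HF", "repo": "meta/llama4-500B"}

    config = TRTLLMSpeculativeDecodingConfiguration(
        target={
            "build": {
                "base_model": "llama",
                "max_seq_len": 2048,
                "checkpoint_repository": checkpoint,
                "speculative_decoding_mode": TrussSpecDecMode.DRAFT_EXTERNAL,
                "max_draft_len": 10,
            }
        },
        draft={
            "build": {
                "base_model": "llama",
                "max_seq_len": 2048,
                "checkpoint_repository": checkpoint,
            },
            "runtime": {"num_draft_tokens": 4},
        },
    )

Setting any one of `speculative_decoding_mode`, `max_draft_len`, or `num_draft_tokens` without the other two trips `_validate_spec_dec` with the error message above.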
+ ) + + def to_json_dict(self, verbose=True): + return json.loads(self.json(exclude_unset=not verbose)) diff --git a/truss/base/truss_config.py b/truss/base/truss_config.py index 21eac56f9..2a002c18d 100644 --- a/truss/base/truss_config.py +++ b/truss/base/truss_config.py @@ -3,14 +3,21 @@ from dataclasses import _MISSING_TYPE, dataclass, field, fields from enum import Enum from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, TypeVar +from typing import Any, Callable, Dict, List, Optional, TypeVar, Union import yaml -from truss.base.constants import HTTP_PUBLIC_BLOB_BACKEND +from truss.base.constants import ( + HTTP_PUBLIC_BLOB_BACKEND, + TRTLLM_SPEC_DEC_TARGET_MODEL_NAME, +) from truss.base.custom_types import ModelFrameworkType from truss.base.errors import ValidationError -from truss.base.trt_llm_config import TRTLLMConfiguration, TrussTRTLLMQuantizationType +from truss.base.trt_llm_config import ( + TRTLLMConfiguration, + TRTLLMSpeculativeDecodingConfiguration, + TrussTRTLLMQuantizationType, +) from truss.base.validation import ( validate_cpu_spec, validate_memory_spec, @@ -558,7 +565,9 @@ class TrussConfig: base_image: Optional[BaseImage] = None docker_server: Optional[DockerServer] = None model_cache: ModelCache = field(default_factory=ModelCache) - trt_llm: Optional[TRTLLMConfiguration] = None + trt_llm: Optional[ + Union[TRTLLMConfiguration, TRTLLMSpeculativeDecodingConfiguration] + ] = None build_commands: List[str] = field(default_factory=list) use_local_chains_src: bool = False @@ -571,6 +580,14 @@ def canonical_python_version(self) -> str: "py38": "3.8", }[self.python_version] + @property + def parsed_trt_llm_configs(self) -> List[TRTLLMConfiguration]: + if self.trt_llm: + if isinstance(self.trt_llm, TRTLLMSpeculativeDecodingConfiguration): + return [self.trt_llm.target, self.trt_llm.draft] + return [self.trt_llm] + return [] + @staticmethod def from_dict(d): config = TrussConfig( @@ -617,7 +634,10 @@ def from_dict(d): ModelCache.from_list, ), trt_llm=transform_optional( - d.get("trt_llm"), lambda x: TRTLLMConfiguration(**x) + d.get("trt_llm"), + lambda x: (TRTLLMConfiguration(**x)) + if TRTLLM_SPEC_DEC_TARGET_MODEL_NAME not in d.get("trt_llm") + else (TRTLLMSpeculativeDecodingConfiguration(**x)), ), build_commands=d.get("build_commands", []), use_local_chains_src=d.get("use_local_chains_src", False), @@ -670,17 +690,17 @@ def to_dict(self, verbose: bool = True): def clone(self): return TrussConfig.from_dict(self.to_dict()) - def _validate_accelerator_for_trt_llm_builder(self) -> None: - if self.trt_llm and self.trt_llm.build: + def _validate_trt_llm_config(self) -> None: + for trt_llm_config in self.parsed_trt_llm_configs: if ( - self.trt_llm.build.quantization_type + trt_llm_config.build.quantization_type is TrussTRTLLMQuantizationType.WEIGHTS_ONLY_INT8 and self.resources.accelerator.accelerator is Accelerator.A100 ): raise ValueError( "Weight only int8 quantization on A100 accelerators is not currently supported" ) - elif self.trt_llm.build.quantization_type in [ + elif trt_llm_config.build.quantization_type in [ TrussTRTLLMQuantizationType.FP8, TrussTRTLLMQuantizationType.FP8_KV, ] and self.resources.accelerator.accelerator not in [ @@ -691,7 +711,7 @@ def _validate_accelerator_for_trt_llm_builder(self) -> None: raise ValueError( "FP8 quantization is only supported on L4 and H100 accelerators" ) - tensor_parallel_count = self.trt_llm.build.tensor_parallel_count + tensor_parallel_count = trt_llm_config.build.tensor_parallel_count if 
diff --git a/truss/cli/cli.py b/truss/cli/cli.py
index e0cb9a088..65fda2cc6 100644
--- a/truss/cli/cli.py
+++ b/truss/cli/cli.py
@@ -44,8 +44,8 @@
 from truss.remote.baseten.utils.status import get_displayable_status
 from truss.remote.remote_factory import USER_TRUSSRC_PATH, RemoteFactory
 from truss.trt_llm.config_checks import (
-    check_and_update_memory_for_trt_llm_builder,
-    check_secrets_for_trt_llm_builder,
+    is_missing_secrets_for_trt_llm_builder,
+    memory_updated_for_trt_llm_builder,
     uses_trt_llm_builder,
 )
 from truss.truss_handle.build import cleanup as _cleanup
@@ -1150,32 +1150,32 @@ def push(
             live_reload_disabled_text = "Development mode is currently not supported for trusses using TRT-LLM build flow, push as a published model using --publish"
             console.print(live_reload_disabled_text, style="red")
             sys.exit(1)
-        if not check_secrets_for_trt_llm_builder(tr):
+        if is_missing_secrets_for_trt_llm_builder(tr):
             missing_token_text = (
                 "`hf_access_token` must be provided in secrets to build a gated model. "
                 "Please see https://docs.baseten.co/deploy/guides/private-model for configuration instructions."
             )
             console.print(missing_token_text, style="red")
             sys.exit(1)
-        if not check_and_update_memory_for_trt_llm_builder(tr):
+        if memory_updated_for_trt_llm_builder(tr):
             console.print(
                 f"Automatically increasing memory for trt-llm builder to {TRTLLM_MIN_MEMORY_REQUEST_GI}Gi."
             )
-        config = tr.spec.config
-        if (
-            config.trt_llm.build.quantization_type
-            in [TrussTRTLLMQuantizationType.FP8, TrussTRTLLMQuantizationType.FP8_KV]
-            and not config.trt_llm.build.num_builder_gpus
-        ):
-            fp8_and_num_builder_gpus_text = (
-                "Warning: build specifies FP8 quantization but does not explicitly specify number of build GPUs. "
-                "GPU memory required at build time may be significantly more than that required at inference time due to FP8 quantization, which can result in OOM failures during the engine build phase."
-                "`num_builder_gpus` can be used to specify the number of GPUs to use at build time."
-            )
-            console.print(
-                fp8_and_num_builder_gpus_text,
-                style="yellow",
-            )
+        for trt_llm_config in tr.spec.config.parsed_trt_llm_configs:
+            if (
+                trt_llm_config.build.quantization_type
+                in [TrussTRTLLMQuantizationType.FP8, TrussTRTLLMQuantizationType.FP8_KV]
+                and not trt_llm_config.build.num_builder_gpus
+            ):
+                fp8_and_num_builder_gpus_text = (
+                    "Warning: build specifies FP8 quantization but does not explicitly specify number of build GPUs. "
+                    "GPU memory required at build time may be significantly more than that required at inference time due to FP8 quantization, which can result in OOM failures during the engine build phase. "
+                    "`num_builder_gpus` can be used to specify the number of GPUs to use at build time."
+                )
+                console.print(
+                    fp8_and_num_builder_gpus_text,
+                    style="yellow",
+                )
 
     # TODO(Abu): This needs to be refactored to be more generic
     service = remote_provider.push(
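
Note (editorial): with the warning now emitted per parsed config, a spec-dec truss can warn for both the target and the draft build. An illustrative fragment that pre-empts it by setting `num_builder_gpus` explicitly; values are fixture-style placeholders, and "fp8" assumes the enum's usual lowercase serialization:

    # Illustrative only, not a sizing recommendation.
    trt_llm = {
        "build": {
            "base_model": "llama",
            "max_seq_len": 2048,
            "checkpoint_repository": {"source": "HF", "repo": "meta/llama4-500B"},
            "quantization_type": "fp8",
            "num_builder_gpus": 2,  # explicit, so the CLI warning is skipped
        }
    }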
+ "`num_builder_gpus` can be used to specify the number of GPUs to use at build time." + ) + console.print( + fp8_and_num_builder_gpus_text, + style="yellow", + ) # TODO(Abu): This needs to be refactored to be more generic service = remote_provider.push( diff --git a/truss/contexts/image_builder/serving_image_builder.py b/truss/contexts/image_builder/serving_image_builder.py index 19c209c38..0338736dd 100644 --- a/truss/contexts/image_builder/serving_image_builder.py +++ b/truss/contexts/image_builder/serving_image_builder.py @@ -41,7 +41,7 @@ TRUSSLESS_MAX_PAYLOAD_SIZE, USER_SUPPLIED_REQUIREMENTS_TXT_FILENAME, ) -from truss.base.trt_llm_config import TrussTRTLLMModel +from truss.base.trt_llm_config import TRTLLMConfiguration, TrussTRTLLMModel from truss.base.truss_config import DEFAULT_BUNDLED_PACKAGES_DIR, BaseImage, TrussConfig from truss.base.truss_spec import TrussSpec from truss.contexts.image_builder.cache_warmer import ( @@ -344,6 +344,59 @@ def __init__(self, truss_dir: Path) -> None: def default_tag(self): return f"{self._spec.model_framework_name}-model:latest" + def _copy_into_build_dir( + self, from_path: Path, build_dir: Path, path_in_build_dir: str + ): + copy_tree_or_file(from_path, build_dir / path_in_build_dir) # type: ignore[operator] + + def prepare_trtllm_build_dir(self, build_dir: Path): + config = self._spec.config + trt_llm_config = config.trt_llm + if not trt_llm_config: + return + is_audio_model = ( + trt_llm_config.build.base_model == TrussTRTLLMModel.WHISPER + if isinstance(trt_llm_config, TRTLLMConfiguration) + and trt_llm_config.build is not None + else False + ) + + if is_audio_model: + copy_tree_path(AUDIO_MODEL_TRTLLM_TRUSS_DIR, build_dir, ignore_patterns=[]) + else: + # trt_llm is treated as an extension at model run time. + self._copy_into_build_dir( + TRTLLM_TRUSS_DIR / "src", + build_dir, + f"{BUILD_SERVER_DIR_NAME}/{BUILD_SERVER_EXTENSIONS_PATH}/trt_llm", + ) + # TODO(pankaj) Do this differently. This is not ideal, user + # supplied code in bundled packages can conflict with those from + # the trtllm extension. We don't want to put this in the build + # directory directly either because of chances of conflict there + # as well and the noise it can create there. We need to find a + # new place that's made available in model's pythonpath. This is + # a bigger lift and feels overkill right now. Worth revisiting + # if we come across cases of actual conflicts. + self._copy_into_build_dir( + TRTLLM_TRUSS_DIR / DEFAULT_BUNDLED_PACKAGES_DIR, + build_dir, + DEFAULT_BUNDLED_PACKAGES_DIR, + ) + + config.runtime.predict_concurrency = TRTLLM_PREDICT_CONCURRENCY + + if not is_audio_model: + config.base_image = BaseImage( + image=TRTLLM_BASE_IMAGE, + python_executable_path=TRTLLM_PYTHON_EXECUTABLE, + ) + config.requirements.extend(BASE_TRTLLM_REQUIREMENTS) + else: + config.requirements.extend(AUDIO_MODEL_TRTLLM_REQUIREMENTS) + config.system_packages.extend(AUDIO_MODEL_TRTLLM_SYSTEM_PACKAGES) + config.python_version = "py310" + def prepare_image_build_dir( self, build_dir: Optional[Path] = None, use_hf_secret: bool = False ): @@ -358,8 +411,7 @@ def prepare_image_build_dir( # TODO(pankaj) We probably don't need model framework specific directory. 
             build_dir = build_truss_target_directory(model_framework_name)
 
-        def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
-            copy_tree_or_file(from_path, build_dir / path_in_build_dir)  # type: ignore[operator]
+        data_dir = build_dir / config.data_dir  # type: ignore[operator]
 
         truss_ignore_patterns = []
         if (truss_dir / USER_TRUSS_IGNORE_FILE).exists():
@@ -371,8 +423,9 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
         copy_tree_path(truss_dir, build_dir, ignore_patterns=truss_ignore_patterns)
 
         if config.docker_server is not None:
-            copy_into_build_dir(
+            self._copy_into_build_dir(
                 TEMPLATES_DIR / "docker_server_requirements.txt",
+                build_dir,
                 "docker_server_requirements.txt",
             )
 
@@ -380,52 +433,7 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
             generate_docker_server_supervisord_config(build_dir, config)
 
-        # Copy over template truss for TRT-LLM (we overwrite the model and packages dir)
-        # Most of the code is pulled from upstream triton-inference-server tensorrtllm_backend
-        # https://github.com/triton-inference-server/tensorrtllm_backend/tree/v0.9.0/all_models/inflight_batcher_llm
-        if config.trt_llm is not None:
-            is_audio_model = (
-                config.trt_llm.build.base_model == TrussTRTLLMModel.WHISPER
-                if config.trt_llm.build is not None
-                else False
-            )
-
-            if is_audio_model:
-                copy_tree_path(
-                    AUDIO_MODEL_TRTLLM_TRUSS_DIR, build_dir, ignore_patterns=[]
-                )
-            else:
-                # trt_llm is treated as an extension at model run time.
-                copy_into_build_dir(
-                    TRTLLM_TRUSS_DIR / "src",
-                    f"{BUILD_SERVER_DIR_NAME}/{BUILD_SERVER_EXTENSIONS_PATH}/trt_llm",
-                )
-                # TODO(pankaj) Do this differently. This is not ideal, user
-                # supplied code in bundled packages can conflict with those from
-                # the trtllm extension. We don't want to put this in the build
-                # directory directly either because of chances of conflict there
-                # as well and the noise it can create there. We need to find a
-                # new place that's made available in model's pythonpath. This is
-                # a bigger lift and feels overkill right now. Worth revisiting
-                # if we come across cases of actual conflicts.
-                copy_into_build_dir(
-                    TRTLLM_TRUSS_DIR / DEFAULT_BUNDLED_PACKAGES_DIR,
-                    DEFAULT_BUNDLED_PACKAGES_DIR,
-                )
-
-            config.runtime.predict_concurrency = TRTLLM_PREDICT_CONCURRENCY
-
-            if not is_audio_model:
-                config.base_image = BaseImage(
-                    image=TRTLLM_BASE_IMAGE,
-                    python_executable_path=TRTLLM_PYTHON_EXECUTABLE,
-                )
-
-                config.requirements.extend(BASE_TRTLLM_REQUIREMENTS)
-            else:
-                config.requirements.extend(AUDIO_MODEL_TRTLLM_REQUIREMENTS)
-                config.system_packages.extend(AUDIO_MODEL_TRTLLM_SYSTEM_PACKAGES)
-                config.python_version = "py310"
+        self.prepare_trtllm_build_dir(build_dir=build_dir)
 
         # Override config.yml
         with (build_dir / CONFIG_FILE).open("w") as config_file:
@@ -448,30 +456,36 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
         )
 
         # Copy inference server code
-        copy_into_build_dir(SERVER_CODE_DIR, BUILD_SERVER_DIR_NAME)
-        copy_into_build_dir(
+        self._copy_into_build_dir(SERVER_CODE_DIR, build_dir, BUILD_SERVER_DIR_NAME)
+        self._copy_into_build_dir(
             SHARED_SERVING_AND_TRAINING_CODE_DIR,
+            build_dir,
             BUILD_SERVER_DIR_NAME + "/" + SHARED_SERVING_AND_TRAINING_CODE_DIR_NAME,
         )
 
         # Copy control server code
         if config.live_reload:
-            copy_into_build_dir(CONTROL_SERVER_CODE_DIR, BUILD_CONTROL_SERVER_DIR_NAME)
-            copy_into_build_dir(
+            self._copy_into_build_dir(
+                CONTROL_SERVER_CODE_DIR, build_dir, BUILD_CONTROL_SERVER_DIR_NAME
+            )
+            self._copy_into_build_dir(
                 SHARED_SERVING_AND_TRAINING_CODE_DIR,
+                build_dir,
                 BUILD_CONTROL_SERVER_DIR_NAME
                 + "/control/"
                 + SHARED_SERVING_AND_TRAINING_CODE_DIR_NAME,
             )
 
         if config.use_local_chains_src:
-            copy_into_build_dir(CHAINS_CODE_DIR, BUILD_CHAINS_DIR_NAME)
+            self._copy_into_build_dir(CHAINS_CODE_DIR, build_dir, BUILD_CHAINS_DIR_NAME)
 
         # Copy base TrussServer requirements if supplied custom base image
         base_truss_server_reqs_filepath = SERVER_CODE_DIR / REQUIREMENTS_TXT_FILENAME
         if config.base_image:
-            copy_into_build_dir(
-                base_truss_server_reqs_filepath, BASE_SERVER_REQUIREMENTS_TXT_FILENAME
+            self._copy_into_build_dir(
+                base_truss_server_reqs_filepath,
+                build_dir,
+                BASE_SERVER_REQUIREMENTS_TXT_FILENAME,
             )
 
         # Copy model framework specific requirements file
@@ -480,7 +494,9 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
         )
         should_install_server_requirements = file_is_not_empty(server_reqs_filepath)
         if should_install_server_requirements:
-            copy_into_build_dir(server_reqs_filepath, SERVER_REQUIREMENTS_TXT_FILENAME)
+            self._copy_into_build_dir(
+                server_reqs_filepath, build_dir, SERVER_REQUIREMENTS_TXT_FILENAME
+            )
 
         with open(base_truss_server_reqs_filepath, "r") as f:
             base_server_requirements = f.read()
@@ -504,8 +520,9 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
             else base_server_requirements
         )
         if spec.requirements_file is not None:
-            copy_into_build_dir(
+            self._copy_into_build_dir(
                 truss_dir / spec.requirements_file,
+                build_dir,
                 USER_SUPPLIED_REQUIREMENTS_TXT_FILENAME,
             )
         (build_dir / REQUIREMENTS_TXT_FILENAME).write_text(
diff --git a/truss/templates/trtllm-briton/src/extension.py b/truss/templates/trtllm-briton/src/extension.py
index d3d1fbc49..c53c105d9 100644
--- a/truss/templates/trtllm-briton/src/extension.py
+++ b/truss/templates/trtllm-briton/src/extension.py
@@ -1,5 +1,8 @@
+from briton.spec_dec_truss_model import Model as SpecDecModel
 from briton.truss_model import Model
 
+TRTLLM_SPEC_DEC_TARGET_MODEL_NAME = "target"
+
 # TODO(pankaj) Define an ABC base class for this. That baseclass should live in
 # a new, smaller truss sub-library, perhaps called `truss-runtime` for inclusion
 # in Truss runtime. Once we have that sub-library, we should define the Extension
@@ -33,7 +36,11 @@ class Extension:
     """
 
     def __init__(self, *args, **kwargs):
-        self._model = Model(*args, **kwargs)
+        self._config = kwargs["config"]
+        if TRTLLM_SPEC_DEC_TARGET_MODEL_NAME in self._config.get("trt_llm"):
+            self._model = SpecDecModel(*args, **kwargs)
+        else:
+            self._model = Model(*args, **kwargs)
 
     def model_override(self):
         """Return a model object.
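
Note (editorial): the runtime side reuses the same sentinel as `TrussConfig.from_dict`. A toy sketch of the membership test (plain dicts stand in for the resolved config; the real code assumes the `trt_llm` key exists, which holds because this extension is only bundled into TRT-LLM trusses):

    def picks_spec_dec_model(config: dict) -> bool:
        # Same test as Extension.__init__ above; a defensive default is added
        # here only to keep the sketch self-contained.
        return "target" in config.get("trt_llm", {})

    assert picks_spec_dec_model({"trt_llm": {"target": {}, "draft": {}}})
    assert not picks_spec_dec_model({"trt_llm": {"build": {}}})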
diff --git a/truss/tests/conftest.py b/truss/tests/conftest.py
index 47fa212a1..9c3621940 100644
--- a/truss/tests/conftest.py
+++ b/truss/tests/conftest.py
@@ -12,6 +12,7 @@
 import yaml
 from truss.base.custom_types import Example
+from truss.base.trt_llm_config import TrussTRTLLMBatchSchedulerPolicy
 from truss.base.truss_config import DEFAULT_BUNDLED_PACKAGES_DIR
 from truss.contexts.image_builder.serving_image_builder import (
     ServingImageBuilderContext,
 )
@@ -400,7 +401,13 @@ def modify_handle(h: TrussHandle):
                 "source": "HF",
                 "repo": "meta/llama4-500B",
             },
-        }
+        },
+        "runtime": {
+            "kv_cache_free_gpu_mem_fraction": 0.9,
+            "enable_chunked_context": False,
+            "num_draft_tokens": None,
+            "batch_scheduler_policy": TrussTRTLLMBatchSchedulerPolicy.GUARANTEED_NO_EVICT.value,
+        },
     }
     content["resources"]["accelerator"] = "H100:1"
 
diff --git a/truss/tests/test_config.py b/truss/tests/test_config.py
index 1fcbfced5..c1401e905 100644
--- a/truss/tests/test_config.py
+++ b/truss/tests/test_config.py
@@ -1,3 +1,4 @@
+import copy
 import tempfile
 from contextlib import nullcontext as does_not_raise
 from pathlib import Path
@@ -7,7 +8,11 @@
 import yaml
 from truss.base.custom_types import ModelFrameworkType
-from truss.base.trt_llm_config import TrussTRTLLMQuantizationType
+from truss.base.trt_llm_config import (
+    TRTLLMSpeculativeDecodingConfiguration,
+    TrussSpecDecMode,
+    TrussTRTLLMQuantizationType,
+)
 from truss.base.truss_config import (
     DEFAULT_CPU,
     DEFAULT_MEMORY,
@@ -65,11 +70,46 @@ def trtllm_config(default_config) -> Dict[str, Any]:
                 "repo": "meta/llama4-500B",
             },
             "gather_all_token_logits": False,
-        }
+        },
+        "runtime": {},
     }
     return trtllm_config
 
 
+@pytest.fixture
+def trtllm_spec_dec_config(trtllm_config) -> Dict[str, Any]:
+    spec_dec_config = copy.deepcopy(trtllm_config)
+    spec_dec_config["trt_llm"] = {
+        "target": {
+            "build": {
+                "base_model": "llama",
+                "max_seq_len": 2048,
+                "max_batch_size": 512,
+                "checkpoint_repository": {
+                    "source": "HF",
+                    "repo": "meta/llama4-500B",
+                },
+                "gather_all_token_logits": False,
+                "speculative_decoding_mode": TrussSpecDecMode.DRAFT_EXTERNAL,
+                "max_draft_len": 10,
+            },
+        },
+        "draft": {
+            "build": {
+                "base_model": "llama",
+                "max_seq_len": 2048,
+                "max_batch_size": 512,
+                "checkpoint_repository": {
+                    "source": "HF",
+                    "repo": "meta/llama4-500B",
+                },
+            },
+            "runtime": {"num_draft_tokens": 4},
+        },
+    }
+    return spec_dec_config
+
+
 @pytest.mark.parametrize(
     "input_dict, expect_resources, output_dict",
     [
@@ -509,10 +549,45 @@ def test_plugin_paged_fp8_context_fmha_check(trtllm_config):
 
 
 @pytest.mark.parametrize("verbose, expect_equal", [(False, True), (True, False)])
-def test_to_dict_trtllm(verbose, expect_equal, trtllm_config):
+def test_to_dict_trtllm(verbose, expect_equal, trtllm_config, trtllm_spec_dec_config):
     assert (
         TrussConfig.from_dict(trtllm_config).to_dict(verbose=verbose) == trtllm_config
     ) == expect_equal
+    assert (
+        TrussConfig.from_dict(trtllm_spec_dec_config).to_dict(verbose=verbose)
+        == trtllm_spec_dec_config
+    ) == expect_equal
+
+
+@pytest.mark.parametrize("should_raise", [False, True])
+def test_from_dict_spec_dec_trt_llm(should_raise, trtllm_spec_dec_config):
+    test_config = copy.deepcopy(trtllm_spec_dec_config)
+    if should_raise:
+        test_config["trt_llm"]["target"]["build"]["speculative_decoding_mode"] = None
+        with pytest.raises(ValueError):
+            TrussConfig.from_dict(test_config)
+        test_config["trt_llm"]["target"]["build"]["speculative_decoding_mode"] = (
+            trtllm_spec_dec_config["trt_llm"]["target"]["build"][
+                "speculative_decoding_mode"
+            ]
+        )
+        test_config["trt_llm"]["draft"]["runtime"]["num_draft_tokens"] = None
+        with pytest.raises(ValueError):
+            TrussConfig.from_dict(test_config)
+    else:
+        TrussConfig.from_dict(trtllm_spec_dec_config)
+
+
+@pytest.mark.parametrize("spec_dec_enabled", [False, True])
+def test_trtllm_spec_dec(spec_dec_enabled, trtllm_config, trtllm_spec_dec_config):
+    config = trtllm_config
+    if spec_dec_enabled:
+        config = trtllm_spec_dec_config
+    truss_config = TrussConfig.from_dict(config)
+    assert (
+        isinstance(truss_config.trt_llm, TRTLLMSpeculativeDecodingConfiguration)
+        == spec_dec_enabled
+    )
 
 
 def test_from_yaml_invalid_requirements_configuration():
diff --git a/truss/tests/util/test_config_checks.py b/truss/tests/util/test_config_checks.py
index 65154de60..cbcd6418b 100644
--- a/truss/tests/util/test_config_checks.py
+++ b/truss/tests/util/test_config_checks.py
@@ -3,8 +3,8 @@
 import pytest
 from truss.base.constants import TRTLLM_MIN_MEMORY_REQUEST_GI
 from truss.trt_llm.config_checks import (
-    check_and_update_memory_for_trt_llm_builder,
-    check_secrets_for_trt_llm_builder,
+    is_missing_secrets_for_trt_llm_builder,
+    memory_updated_for_trt_llm_builder,
 )
 from truss.truss_handle.truss_handle import TrussHandle
 
@@ -13,13 +13,13 @@
 @pytest.mark.parametrize(
     "has_secret, is_model_public, expected_result",
     [
-        (False, False, False),
-        (False, True, True),
-        (True, False, True),
-        (True, True, True),
+        (False, False, True),
+        (False, True, False),
+        (True, False, False),
+        (True, True, False),
     ],
 )
-def test_check_secrets_for_trt_llm_builder(
+def test_is_missing_secrets_for_trt_llm_builder(
     _is_model_public_mock,
     has_secret,
     is_model_public,
@@ -30,11 +30,11 @@ def test_check_secrets_for_trt_llm_builder(
     handle = TrussHandle(custom_model_trt_llm)
     if has_secret:
         handle.add_secret("hf_access_token")
-    assert check_secrets_for_trt_llm_builder(handle) == expected_result
+    assert is_missing_secrets_for_trt_llm_builder(handle) == expected_result
 
 
 def test_check_and_update_memory_for_trt_llm_builder(custom_model_trt_llm):
     handle = TrussHandle(custom_model_trt_llm)
-    assert not check_and_update_memory_for_trt_llm_builder(handle)
+    assert memory_updated_for_trt_llm_builder(handle)
     assert handle.spec.memory == f"{TRTLLM_MIN_MEMORY_REQUEST_GI}Gi"
     assert handle.spec.memory_in_bytes == TRTLLM_MIN_MEMORY_REQUEST_GI * 1024**3
diff --git a/truss/trt_llm/config_checks.py b/truss/trt_llm/config_checks.py
index fc6964151..2562af6e1 100644
--- a/truss/trt_llm/config_checks.py
+++ b/truss/trt_llm/config_checks.py
@@ -8,26 +8,26 @@
 from truss.truss_handle.truss_handle import TrussHandle
 
 
-def check_secrets_for_trt_llm_builder(tr: TrussHandle) -> bool:
-    if tr.spec.config.trt_llm and tr.spec.config.trt_llm.build:
-        source = tr.spec.config.trt_llm.build.checkpoint_repository.source
-        hf_model_id = tr.spec.config.trt_llm.build.checkpoint_repository.repo
+def is_missing_secrets_for_trt_llm_builder(tr: TrussHandle) -> bool:
+    for trt_llm_config in tr.spec.config.parsed_trt_llm_configs:
+        source = trt_llm_config.build.checkpoint_repository.source
+        hf_model_id = trt_llm_config.build.checkpoint_repository.repo
         if (
             source == CheckpointSource.HF
             and HF_ACCESS_TOKEN_KEY not in tr.spec.secrets
             and not _is_model_public(hf_model_id)
         ):
-            return False
-    return True
+            return True
+    return False
 
 
-def check_and_update_memory_for_trt_llm_builder(tr: TrussHandle) -> bool:
+def memory_updated_for_trt_llm_builder(tr: TrussHandle) -> bool:
     if uses_trt_llm_builder(tr):
         if tr.spec.memory_in_bytes < TRTLLM_MIN_MEMORY_REQUEST_GI * 1024**3:
             tr.spec.config.resources.memory = f"{TRTLLM_MIN_MEMORY_REQUEST_GI}Gi"
             tr.spec.config.write_to_yaml_file(tr.spec.config_path, verbose=False)
-            return False
-    return True
+            return True
+    return False
 
 
 def _is_model_public(model_id: str) -> bool:
@@ -40,6 +40,4 @@ def _is_model_public(model_id: str) -> bool:
 
 
 def uses_trt_llm_builder(tr: TrussHandle) -> bool:
-    return (
-        tr.spec.config.trt_llm is not None and tr.spec.config.trt_llm.build is not None
-    )
+    return tr.spec.config.trt_llm is not None
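
Note (editorial): both renamed helpers invert the old polarity; they now return True when something is wrong or was changed, so call sites read without a `not`. A condensed sketch of the new contract (`push_preflight` is a hypothetical wrapper, cf. the real call sites in `truss/cli/cli.py` above):

    from truss.base.constants import TRTLLM_MIN_MEMORY_REQUEST_GI
    from truss.trt_llm.config_checks import (
        is_missing_secrets_for_trt_llm_builder,
        memory_updated_for_trt_llm_builder,
    )

    def push_preflight(tr) -> None:
        # True now means "a required HF secret is missing" -> abort the push.
        if is_missing_secrets_for_trt_llm_builder(tr):
            raise SystemExit("`hf_access_token` must be provided in secrets.")
        # True now means "memory was bumped and config.yaml rewritten" -> inform.
        if memory_updated_for_trt_llm_builder(tr):
            print(f"Automatically increased memory to {TRTLLM_MIN_MEMORY_REQUEST_GI}Gi.")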