diff --git a/pyproject.toml b/pyproject.toml
index d28e8b98f..71bcc51dd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "truss"
-version = "0.9.54rc2"
+version = "0.9.54rc3"
 description = "A seamless bridge from model development to model delivery"
 license = "MIT"
 readme = "README.md"
diff --git a/truss/templates/trtllm-briton/src/extension.py b/truss/templates/trtllm-briton/src/extension.py
index bcbe29915..90ff49689 100644
--- a/truss/templates/trtllm-briton/src/extension.py
+++ b/truss/templates/trtllm-briton/src/extension.py
@@ -1,7 +1,7 @@
 from briton.spec_dec_truss_model import Model as SpecDecModel
 from briton.truss_model import Model
-from truss.base.trt_llm_config import TRTLLMSpeculativeDecodingConfiguration
-from truss.base.truss_config import TrussConfig
+
+TRTLLM_SPEC_DEC_TARGET_MODEL_NAME = "target"
 
 # TODO(pankaj) Define an ABC base class for this. That baseclass should live in
 # a new, smaller truss sub-library, perhaps called `truss-runtime`` for inclusion
@@ -36,8 +36,11 @@ class Extension:
     """
 
     def __init__(self, *args, **kwargs):
-        self._config = TrussConfig(**kwargs["config"])
-        if isinstance(self._config.trt_llm, TRTLLMSpeculativeDecodingConfiguration):
+        self._config = kwargs["config"]
+        # Use SpecDecModel when the raw `trt_llm` section carries the target-model
+        # key (presumably the speculative-decoding layout); else, or if absent, Model.
+        trt_llm_config = self._config.get("trt_llm") or {}
+        if TRTLLM_SPEC_DEC_TARGET_MODEL_NAME in trt_llm_config:
             self._model = SpecDecModel(*args, **kwargs)
         else:
             self._model = Model(*args, **kwargs)