From 19ec8676df8c99b9c520d81bd3ffe439d1f6dd2d Mon Sep 17 00:00:00 2001
From: Salar Hosseini
Date: Wed, 16 Oct 2024 16:39:09 +0000
Subject: [PATCH] #0: Modify vllm llama70b initialization to allow larger seq
 lens and different batch sizes

Signed-off-by: Salar Hosseini
---
 models/demos/t3000/llama2_70b/tt/llama_generation.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/models/demos/t3000/llama2_70b/tt/llama_generation.py b/models/demos/t3000/llama2_70b/tt/llama_generation.py
index 003973874bad..91a8c4925436 100644
--- a/models/demos/t3000/llama2_70b/tt/llama_generation.py
+++ b/models/demos/t3000/llama2_70b/tt/llama_generation.py
@@ -62,15 +62,15 @@ def __init__(self, configuration, state_dict, model_args, tt_args, paged_attenti
         del state_dict
 
     @classmethod
-    def initialize_vllm_model(cls, hf_config, t3k_mesh_device):
+    def initialize_vllm_model(cls, hf_config, t3k_mesh_device, max_batch_size):
         # TODO: pass in model args and tt args as parameters from vllm
         @dataclass
         class ModelArgs:
             llama_version: str = None
             ckpt_dir: str = None
-            max_batch_size: int = 32
+            max_batch_size: int = 32  # overwritten by max_num_seqs from vllm
             num_layers: int = 80
-            max_kv_context_len: int = 4096
+            max_kv_context_len: int = 131072
 
         @dataclass
         class TTArgs:
@@ -85,7 +85,7 @@ class TTArgs:
         check_mesh_device(t3k_mesh_device, model_config)
 
         # initialize arg classes
-        model_args = ModelArgs(llama_version=llama_version, ckpt_dir=ckpt_dir)
+        model_args = ModelArgs(llama_version=llama_version, ckpt_dir=ckpt_dir, max_batch_size=max_batch_size)
         tt_args = TTArgs(mesh_device=t3k_mesh_device, cache_path=cache_path)
 
         # load state dict
@@ -108,6 +108,10 @@ class TTArgs:
             configuration=configuration, state_dict=state_dict, model_args=model_args, tt_args=tt_args, vllm=True
         )
 
+    @property
+    def cache_path(self):
+        return self.tt_model.cache_path
+
     def forward(self, tokens: torch.Tensor, start_pos: int, page_table=None, kv_cache=None, prompt_lens=None):
         _, seq_len = tokens.shape
         if seq_len == 1:
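
Usage note (illustrative, not part of the patch): a minimal sketch of how the reworked classmethod might be called from the vLLM side, assuming a hypothetical wrapper function create_tt_model and the class name TtLlamaModelForGeneration; the actual integration point lives in the vLLM fork, not in this file. vLLM's scheduler-level max_num_seqs is forwarded as max_batch_size, while long prompts are bounded by the raised max_kv_context_len (131072) in ModelArgs.

    # Hypothetical vLLM-side call site; wrapper and class names are assumptions
    # for illustration only.
    from models.demos.t3000.llama2_70b.tt.llama_generation import TtLlamaModelForGeneration

    def create_tt_model(hf_config, t3k_mesh_device, scheduler_config):
        # Build the TT model for the batch size the vLLM scheduler will actually use.
        return TtLlamaModelForGeneration.initialize_vllm_model(
            hf_config,
            t3k_mesh_device,
            max_batch_size=scheduler_config.max_num_seqs,
        )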