#0: Modify vllm llama70b initialization to allow larger seq lens and different batch sizes

Signed-off-by: Salar Hosseini <skhorasgani@tenstorrent.com>
skhorasganiTT committed Oct 16, 2024
1 parent aee03c7 commit 19ec867
Showing 1 changed file with 8 additions and 4 deletions.
models/demos/t3000/llama2_70b/tt/llama_generation.py
@@ -62,15 +62,15 @@ def __init__(self, configuration, state_dict, model_args, tt_args, paged_attenti
         del state_dict
 
     @classmethod
-    def initialize_vllm_model(cls, hf_config, t3k_mesh_device):
+    def initialize_vllm_model(cls, hf_config, t3k_mesh_device, max_batch_size):
         # TODO: pass in model args and tt args as parameters from vllm
         @dataclass
         class ModelArgs:
             llama_version: str = None
             ckpt_dir: str = None
-            max_batch_size: int = 32
+            max_batch_size: int = 32  # overwritten by max_num_seqs from vllm
             num_layers: int = 80
-            max_kv_context_len: int = 4096
+            max_kv_context_len: int = 131072
 
         @dataclass
         class TTArgs:
@@ -85,7 +85,7 @@ class TTArgs:
         check_mesh_device(t3k_mesh_device, model_config)
 
         # initialize arg classes
-        model_args = ModelArgs(llama_version=llama_version, ckpt_dir=ckpt_dir)
+        model_args = ModelArgs(llama_version=llama_version, ckpt_dir=ckpt_dir, max_batch_size=max_batch_size)
         tt_args = TTArgs(mesh_device=t3k_mesh_device, cache_path=cache_path)
 
         # load state dict
@@ -108,6 +108,10 @@ class TTArgs:
             configuration=configuration, state_dict=state_dict, model_args=model_args, tt_args=tt_args, vllm=True
         )
 
+    @property
+    def cache_path(self):
+        return self.tt_model.cache_path
+
     def forward(self, tokens: torch.Tensor, start_pos: int, page_table=None, kv_cache=None, prompt_lens=None):
         _, seq_len = tokens.shape
         if seq_len == 1:

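For context, a minimal sketch of how the updated classmethod could be driven from the vLLM side. This is not part of the commit: it assumes the class defined in llama_generation.py is TtLlamaModelForGeneration, that the caller already holds an opened T3000 mesh device, and the checkpoint name and helper function are illustrative placeholders.

# Illustrative only: constructing the TT Llama model with the new max_batch_size argument.
from transformers import AutoConfig

from models.demos.t3000.llama2_70b.tt.llama_generation import TtLlamaModelForGeneration


def build_tt_llama(mesh_device, max_num_seqs: int = 16):
    # Hypothetical checkpoint name; any compatible Llama 70B HF config would do.
    hf_config = AutoConfig.from_pretrained("meta-llama/Llama-3.1-70B-Instruct")
    # vLLM's max_num_seqs (scheduler batch limit) is forwarded as max_batch_size,
    # overriding the ModelArgs default of 32.
    model = TtLlamaModelForGeneration.initialize_vllm_model(hf_config, mesh_device, max_num_seqs)
    # The new cache_path property exposes the underlying tt_model's weight-cache location.
    print("weight cache:", model.cache_path)
    return model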