Skip to content

Commit

Permalink
#0: Add assert for supported prefill seq lens in llama70b
Browse files Browse the repository at this point in the history
Signed-off-by: Salar Hosseini <skhorasgani@tenstorrent.com>
  • Loading branch information
skhorasganiTT committed Oct 16, 2024
1 parent 19ec867 commit 0b2c29f
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
9 changes: 7 additions & 2 deletions models/demos/t3000/llama2_70b/tt/llama_model_optimized.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def load_weights(self):
)
self.norm_sharded = ttnn.to_device(norm_sharded_ttnn, self.mesh_device)

def validate_input_shape(self, inp_ids):
def validate_input_shape(self, inp_ids, mode):
    """Validate the shape of the input token ids against the model limits.

    Args:
        inp_ids: 2D tensor of token ids, shape (batch, seq_len).
        mode: "prefill" or "decode"; prefill additionally bounds seq_len
            by MAX_PREFILL_SEQ_LEN.

    Raises:
        AssertionError: if the input rank is not 2 or any configured
            limit (batch, context length, prefill length) is exceeded.
    """
    assert inp_ids.dim() == 2
    batch, seq_len = inp_ids.shape
    # NOTE(review): this batch check was collapsed in the diff view;
    # reconstructed from the parallel seq_len check below and the
    # MAX_BATCH_SIZE config key — confirm against the full source.
    assert (
        batch <= self.model_config["MAX_BATCH_SIZE"]
    ), f"Batch size {batch} exceeds MAX_BATCH_SIZE {self.model_config['MAX_BATCH_SIZE']}"
    assert (
        seq_len <= self.model_config["MAX_CONTEXT_LEN"]
    ), f"Sequence length {seq_len} exceeds MAX_CONTEXT_LEN {self.model_config['MAX_CONTEXT_LEN']}"

    if mode == "prefill":
        # Fix: use <= so the maximum supported prefill length is itself
        # accepted; the original strict `<` off-by-one rejected
        # seq_len == MAX_PREFILL_SEQ_LEN, contradicting the inclusive
        # semantics of a MAX_* bound (cf. MAX_CONTEXT_LEN above).
        assert (
            seq_len <= self.model_config["MAX_PREFILL_SEQ_LEN"]
        ), f"Prefill only supports seq_len <= {self.model_config['MAX_PREFILL_SEQ_LEN']}"

def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None, mode="decode", page_table=None):
"""
Prepare inputs for decode mode. Assume that current token is at
Expand All @@ -182,7 +187,7 @@ def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None, mode="decode",
rot_mats: [(1, 1, head_dim, head_dim)] * num_devices for decode
[(1, 1, seq, head_dim), (1, 1, seq, head_dim)] * num_devices for prefill
"""
self.validate_input_shape(inp_ids)
self.validate_input_shape(inp_ids, mode)
batch, seq_len = inp_ids.shape

cache_name = lambda name: self.cache_path / (f"{'llama3_' if self.llama3 else ''}{name}")
Expand Down
1 change: 1 addition & 0 deletions models/demos/t3000/llama2_70b/tt/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def get_model_config(llama_version="llama3", max_batch_size=32, max_context_len=
"ALL_GATHER_NUM_LINKS": 1,
"MAX_BATCH_SIZE": max_batch_size,
"MAX_CONTEXT_LEN": max_context_len,
"MAX_PREFILL_SEQ_LEN": 32768, # TODO: remove once larger seq lens can be prefilled via decode in TtLlamaModelForGeneration
"NUM_DEVICES": num_devices,
"llama3-tg": MAX_SEQ_LEN_LLAMA3,
"llama3.1-tg": MAX_SEQ_LEN_LLAMA3_1,
Expand Down

0 comments on commit 0b2c29f

Please sign in to comment.