diff --git a/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py b/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py index 66c058bd503a..63476bd4b26c 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py +++ b/models/demos/t3000/llama2_70b/tt/llama_model_optimized.py @@ -158,7 +158,7 @@ def load_weights(self): ) self.norm_sharded = ttnn.to_device(norm_sharded_ttnn, self.mesh_device) - def validate_input_shape(self, inp_ids): + def validate_input_shape(self, inp_ids, mode): assert inp_ids.dim() == 2 batch, seq_len = inp_ids.shape assert ( @@ -168,6 +168,11 @@ def validate_input_shape(self, inp_ids): seq_len <= self.model_config["MAX_CONTEXT_LEN"] ), f"Sequence length {seq_len} exceeds MAX_CONTEXT_LEN {self.model_config['MAX_CONTEXT_LEN']}" + if mode == "prefill": + assert ( + seq_len < self.model_config["MAX_PREFILL_SEQ_LEN"] + ), f"Prefill only supports seq_len < {self.model_config['MAX_PREFILL_SEQ_LEN']}" + def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None, mode="decode", page_table=None): """ Prepare inputs for decode mode. Assume that current token is at @@ -182,7 +187,7 @@ def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None, mode="decode", rot_mats: [(1, 1, head_dim, head_dim)] * num_devices for decode [(1, 1, seq, head_dim), (1, 1, seq, head_dim)] * num_devices for prefill """ - self.validate_input_shape(inp_ids) + self.validate_input_shape(inp_ids, mode) batch, seq_len = inp_ids.shape cache_name = lambda name: self.cache_path / (f"{'llama3_' if self.llama3 else ''}{name}") diff --git a/models/demos/t3000/llama2_70b/tt/model_config.py b/models/demos/t3000/llama2_70b/tt/model_config.py index dfb3b100e47e..65396108b431 100644 --- a/models/demos/t3000/llama2_70b/tt/model_config.py +++ b/models/demos/t3000/llama2_70b/tt/model_config.py @@ -69,6 +69,7 @@ def get_model_config(llama_version="llama3", max_batch_size=32, max_context_len= "ALL_GATHER_NUM_LINKS": 1, "MAX_BATCH_SIZE": max_batch_size, "MAX_CONTEXT_LEN": max_context_len, + "MAX_PREFILL_SEQ_LEN": 32768, # TODO: remove once larger seq lens can be prefilled via decode in TtLlamaModelForGeneration "NUM_DEVICES": num_devices, "llama3-tg": MAX_SEQ_LEN_LLAMA3, "llama3.1-tg": MAX_SEQ_LEN_LLAMA3_1,