diff --git a/models/demos/llama3/tests/test_llama_model_prefill.py b/models/demos/llama3/tests/test_llama_model_prefill.py
index 97e614e2e7da..ca48efd8b118 100644
--- a/models/demos/llama3/tests/test_llama_model_prefill.py
+++ b/models/demos/llama3/tests/test_llama_model_prefill.py
@@ -58,7 +58,7 @@ def test_llama_model_inference(mesh_device, seq_len, use_program_cache, reset_se
 
     # Use instruct weights instead of general weights
     instruct = True
-    model_args = TtModelArgs(mesh_device, instruct=instruct)
+    model_args = TtModelArgs(mesh_device, instruct=instruct, max_batch_size=1)
     tokenizer = Tokenizer(model_args.tokenizer_path)
 
     logger.info("Loading weights...")
@@ -84,12 +84,12 @@ def test_llama_model_inference(mesh_device, seq_len, use_program_cache, reset_se
     prompt_file = os.path.join(current_file_dir, "tale-of-two-cities.txt.bz2")
 
     with bz2.open(prompt_file, "rt", encoding="utf-8") as f:
-        prompts = f.read()
+        prompt = f.read()
 
     if instruct:
-        encoded_prompts = [encode_prompt_llama_instruct(tokenizer, prompt) for prompt in prompts]
+        encoded_prompt = encode_prompt_llama_instruct(tokenizer, prompt)[:seq_len]
     else:
-        encoded_prompts = tokenizer.encode(prompts, bos=True, eos=False)[:seq_len]
+        encoded_prompt = tokenizer.encode(prompt, bos=True, eos=False)[:seq_len]
 
     if run_ref_pt:
         reference_model = Transformer(model_args)
@@ -126,9 +126,9 @@ def test_llama_model_inference(mesh_device, seq_len, use_program_cache, reset_se
 
     batch = 1
 
-    # Select the first token from the prompts for initial decoding
-    encoded_prompts_tensor = torch.tensor(encoded_prompts)  # [:,0]
-    pt_decode_input = embd(encoded_prompts_tensor).view(batch, seq_len, -1)
+    # Select the first token from the prompt for initial decoding
+    encoded_prompt_tensor = torch.tensor(encoded_prompt)  # [:,0]
+    pt_decode_input = embd(encoded_prompt_tensor).view(batch, seq_len, -1)
 
     tt_decode_input = pt_decode_input