Skip to content

Commit

Permalink
#0: work around segfault in main
Browse files (browse the repository at this point in the history)
  • Loading branch information
yieldthought committed Nov 20, 2024
1 parent 8487461 commit 768b956
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions models/demos/llama3/tests/test_llama_model_prefill.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_llama_model_inference(mesh_device, seq_len, use_program_cache, reset_se
# Use instruct weights instead of general weights
instruct = True

- model_args = TtModelArgs(mesh_device, instruct=instruct)
+ model_args = TtModelArgs(mesh_device, instruct=instruct, max_batch_size=1)
tokenizer = Tokenizer(model_args.tokenizer_path)

logger.info("Loading weights...")
Expand All @@ -84,12 +84,12 @@ def test_llama_model_inference(mesh_device, seq_len, use_program_cache, reset_se
prompt_file = os.path.join(current_file_dir, "tale-of-two-cities.txt.bz2")

with bz2.open(prompt_file, "rt", encoding="utf-8") as f:
- prompts = f.read()
+ prompt = f.read()

if instruct:
- encoded_prompts = [encode_prompt_llama_instruct(tokenizer, prompt) for prompt in prompts]
+ encoded_prompt = encode_prompt_llama_instruct(tokenizer, prompt)[:seq_len]
else:
- encoded_prompts = tokenizer.encode(prompts, bos=True, eos=False)[:seq_len]
+ encoded_prompt = tokenizer.encode(prompt, bos=True, eos=False)[:seq_len]

if run_ref_pt:
reference_model = Transformer(model_args)
Expand Down Expand Up @@ -126,9 +126,9 @@ def test_llama_model_inference(mesh_device, seq_len, use_program_cache, reset_se

batch = 1

- # Select the first token from the prompts for initial decoding
- encoded_prompts_tensor = torch.tensor(encoded_prompts) # [:,0]
- pt_decode_input = embd(encoded_prompts_tensor).view(batch, seq_len, -1)
+ # Select the first token from the prompt for initial decoding
+ encoded_prompt_tensor = torch.tensor(encoded_prompt) # [:,0]
+ pt_decode_input = embd(encoded_prompt_tensor).view(batch, seq_len, -1)

tt_decode_input = pt_decode_input

Expand Down

0 comments on commit 768b956

Please sign in to comment.