Skip to content

Commit

Permalink
#0: Remove debug code to speed up demo
Browse files · Browse the repository at this point in the history
  • Loading branch information
mtairum committed Nov 21, 2024
1 parent a764cb8 commit 1454ae2
Showing 1 changed file with 0 additions and 47 deletions.
47 changes: 0 additions & 47 deletions models/demos/llama3/demo/demo.py
Original file line number Diff line number Diff line change
Expand Up
@@ -449,8 +449,6 @@ def run_llama3_demo(user_input, single_layer, mesh_device, instruct_mode, is_ci_
trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)

decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok))
# TODO Miguel: I think the problem is here, not updating the get rot mats
# The problem is that the get_rot_mats is using embedding that ends up on the host.
rot_mats = rope_setup.get_rot_mats(rot_mat_idxs)
tt_out = tt_model(
decode_input,
Expand Down
Expand Up
@@ -532,51 +530,6 @@ def run_llama3_demo(user_input, single_layer, mesh_device, instruct_mode, is_ci_
)[0, 0, 0, :batch_size]
ttnn.record_event(1, write_event)

# TODO Miguel Remove
print("==== ITERATION", iteration, "====")
# Check input
input_torch = ttnn.to_torch(decode_input, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=3))
for i in range(batch_size):
input_equal = torch.eq(input_torch[:, :, 0, :], input_torch[:, :, i, :]).all()
if not input_equal:
print("Batch", i, "input not equal")

# Check output
for i in range(batch_size):
out_equal = torch.eq(tt_output_torch[0], tt_output_torch[i])
if not out_equal:
print("Batch", i, "output not equal")

# Check KV cache [Mismatch]
k_cache = ttnn.to_torch(
tt_model.layers[0].attention.layer_past[0], mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)
)
v_cache = ttnn.to_torch(
tt_model.layers[0].attention.layer_past[1], mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)
)
for i in range(batch_size):
k_equal = torch.eq(k_cache[0, :, :, :], k_cache[i, :, :, :]).all()
v_equal = torch.eq(v_cache[0, :, :, :], v_cache[i, :, :, :]).all()
if not k_equal:
print("Batch", i, "k_cache not equal")
# print(f"PCC = {comp_pcc(k_cache[0,:,:,:], k_cache[i,:,:,:])}")
if not v_equal:
print("Batch", i, "v_cache not equal")
# print(f"PCC = {comp_pcc(v_cache[0,:,:,:], v_cache[i,:,:,:])}")

# Check rot mats [All equal]
cos_out = ttnn.to_torch(rot_mats[0], mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0))[0, :, :, :]
sin_out = ttnn.to_torch(rot_mats[1], mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0))[0, :, :, :]

for i in range(batch_size):
cos_equal = torch.eq(cos_out[0, :, :], cos_out[i, :, :]).all()
sin_equal = torch.eq(sin_out[0, :, :], sin_out[i, :, :]).all()
if not cos_equal:
print("Batch", i, "cos not equal")
if not sin_equal:
print("Batch", i, "sin not equal")
###########

# Save output token to print out later
for user in range(batch_size):
user_tok = tt_output_torch[user].tolist()
Expand Down

0 comments on commit 1454ae2

Please sign in to comment.