diff --git a/models/demos/llama3/demo/demo.py b/models/demos/llama3/demo/demo.py index de80ee2961d..0ce7216353c 100644 --- a/models/demos/llama3/demo/demo.py +++ b/models/demos/llama3/demo/demo.py @@ -449,8 +449,6 @@ def run_llama3_demo(user_input, single_layer, mesh_device, instruct_mode, is_ci_ trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0) decode_input = ttnn.unsqueeze_to_4D(tt_embd(tt_out_tok)) - # TODO Miguel: I think the problem is here, not updating the get rot mats - # The problem is that the get_rot_mats is using embedding that ends up on the host. rot_mats = rope_setup.get_rot_mats(rot_mat_idxs) tt_out = tt_model( decode_input, @@ -532,51 +530,6 @@ def run_llama3_demo(user_input, single_layer, mesh_device, instruct_mode, is_ci_ )[0, 0, 0, :batch_size] ttnn.record_event(1, write_event) - # TODO Miguel Remove - print("==== ITERATION", iteration, "====") - # Check input - input_torch = ttnn.to_torch(decode_input, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=3)) - for i in range(batch_size): - input_equal = torch.eq(input_torch[:, :, 0, :], input_torch[:, :, i, :]).all() - if not input_equal: - print("Batch", i, "input not equal") - - # Check output - for i in range(batch_size): - out_equal = torch.eq(tt_output_torch[0], tt_output_torch[i]) - if not out_equal: - print("Batch", i, "output not equal") - - # Check KV cache [Mismatch] - k_cache = ttnn.to_torch( - tt_model.layers[0].attention.layer_past[0], mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1) - ) - v_cache = ttnn.to_torch( - tt_model.layers[0].attention.layer_past[1], mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1) - ) - for i in range(batch_size): - k_equal = torch.eq(k_cache[0, :, :, :], k_cache[i, :, :, :]).all() - v_equal = torch.eq(v_cache[0, :, :, :], v_cache[i, :, :, :]).all() - if not k_equal: - print("Batch", i, "k_cache not equal") - # print(f"PCC = {comp_pcc(k_cache[0,:,:,:], k_cache[i,:,:,:])}") - if not v_equal: - print("Batch", i, "v_cache not equal") - # print(f"PCC = {comp_pcc(v_cache[0,:,:,:], v_cache[i,:,:,:])}") - - # Check rot mats [All equal] - cos_out = ttnn.to_torch(rot_mats[0], mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0))[0, :, :, :] - sin_out = ttnn.to_torch(rot_mats[1], mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0))[0, :, :, :] - - for i in range(batch_size): - cos_equal = torch.eq(cos_out[0, :, :], cos_out[i, :, :]).all() - sin_equal = torch.eq(sin_out[0, :, :], sin_out[i, :, :]).all() - if not cos_equal: - print("Batch", i, "cos not equal") - if not sin_equal: - print("Batch", i, "sin not equal") - ########### - # Save output token to print out later for user in range(batch_size): user_tok = tt_output_torch[user].tolist()