diff --git a/models/demos/llama3/demo/demo.py b/models/demos/llama3/demo/demo.py
index 756c60c444b8..8aaabc443fb1 100644
--- a/models/demos/llama3/demo/demo.py
+++ b/models/demos/llama3/demo/demo.py
@@ -15,6 +15,7 @@
 from pathlib import Path
 import hashlib
 
+from models.utility_functions import nearest_32
 from models.demos.llama3.tt.llama_common import (
     get_prefill_rot_mat,
     get_rot_transformation_mat,
@@ -407,6 +408,10 @@ def run_llama3_demo(user_input, single_layer, mesh_device, instruct_mode, is_ci_
     # Initial positions
     current_pos = torch.tensor([decoding_pos[b] for b in range(batch_size)])
+
+    pad_size = nearest_32(batch_size) - batch_size
+    current_pos_padded = torch.nn.functional.pad(current_pos, (0, pad_size), "constant", 0)
+
     current_pos_tensor = ttnn.from_torch(
         current_pos,
         device=mesh_device,
     )
@@ -415,8 +420,8 @@ def run_llama3_demo(user_input, single_layer, mesh_device, instruct_mode, is_ci_
 
     # Get cos/sin matrices for the current position of each user
-    rot_mats = rope_setup.get_rot_mats(current_pos)
-    rot_mat_idxs = rope_setup.get_rot_idxs(current_pos)
+    rot_mats = rope_setup.get_rot_mats(current_pos_padded)
+    rot_mat_idxs = rope_setup.get_rot_idxs(current_pos_padded)
 
     # Compile
     logger.info(f"Compiling model trace...")
@@ -490,7 +495,7 @@ def run_llama3_demo(user_input, single_layer, mesh_device, instruct_mode, is_ci_
     # Reset the current position and output token tensors for the real decode run
     ttnn.copy_host_to_device_tensor(current_pos_reset, current_pos_tensor)
     ttnn.copy_host_to_device_tensor(tt_out_tok_reset, tt_out_tok)
-    rot_mat_idxs_reset = rope_setup.get_rot_idxs(current_pos, on_host=True)
+    rot_mat_idxs_reset = rope_setup.get_rot_idxs(current_pos_padded, on_host=True)
     ttnn.copy_host_to_device_tensor(rot_mat_idxs_reset, rot_mat_idxs)
 
     profiler.end(f"capture_trace_{batch_idx}")
@@ -519,7 +524,8 @@ def run_llama3_demo(user_input, single_layer, mesh_device, instruct_mode, is_ci_
         # TODO This is required for now since we cannot ttnn.plus_one(rot_mat_idxs) while it being uint32.
         # If this tensor is int32, it won't be supported by ttnn.embedding
         current_pos += 1
-        rot_mat_idxs_updated = rope_setup.get_rot_idxs(current_pos, on_host=True)
+        current_pos_padded += 1
+        rot_mat_idxs_updated = rope_setup.get_rot_idxs(current_pos_padded, on_host=True)
         ttnn.copy_host_to_device_tensor(rot_mat_idxs_updated, rot_mat_idxs)
 
         # Write to host
diff --git a/models/demos/llama3/tt/llama_rope.py b/models/demos/llama3/tt/llama_rope.py
index d4e03c23a932..7b6caddd4663 100644
--- a/models/demos/llama3/tt/llama_rope.py
+++ b/models/demos/llama3/tt/llama_rope.py
@@ -94,8 +94,8 @@ def get_rot_idxs(self, position_idxs, on_host=False):
         assert len(position_idxs.shape) == 1, "position idxs must be a [batch] tensor"
 
         batch = position_idxs.shape[0]
-        position_idxs = position_idxs.reshape(1, 1, 1, batch)  # [1, 1, 1, batch]
-        assert position_idxs.shape == (1, 1, 1, batch), "position idxs must be a [1, batch] tensor"
+        position_idxs = position_idxs.reshape(1, batch)  # [1, batch]
+        assert position_idxs.shape == (1, batch), "position idxs must be a [1, batch] tensor"
         assert torch.min(position_idxs) >= 0, "position idxs must be non-negative"
 
         if on_host:
@@ -125,20 +125,15 @@ def get_rot_mats(self, position_idxs, return_rot_idxs=False):
             rot_idxs = self.get_rot_idxs(position_idxs)
         else:
             rot_idxs = position_idxs
-            assert rot_idxs.shape == [1, 1, 1, self.batch_size], "rot_idxs must be a [1, 1, 1, batch] tensor"
+            assert len(rot_idxs.shape) == 2 and rot_idxs.shape == [
+                1,
+                rot_idxs.shape[1],
+            ], "rot_idxs must be a [1, batch] tensor"
 
         # Send the idxs to device
         if rot_idxs.device != device:
             rot_idxs = ttnn.to_device(rot_idxs, device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
 
-        batch = rot_idxs.shape[3]
-
-        # Pad the batch dimension to be a multiple of TILE_SIZE
-        if batch % ttnn.TILE_SIZE != 0:
-            pad_size = ttnn.TILE_SIZE - (batch % ttnn.TILE_SIZE)
-            rot_idxs = ttnn.pad(rot_idxs, [1, 1, 1, batch + pad_size], [0, 0, 0, batch], 0.0)
-            batch = rot_idxs.shape[3]
-
         embedding_layout = ttnn.TILE_LAYOUT
         cos = ttnn.embedding(rot_idxs, self.cos_matrix, layout=embedding_layout)  # [1, batch, head_dim]
         sin = ttnn.embedding(rot_idxs, self.sin_matrix, layout=embedding_layout)  # [1, batch, head_dim]
@@ -149,8 +144,9 @@ def get_rot_mats(self, position_idxs, return_rot_idxs=False):
         cos = ttnn.transpose(cos, 1, 2)  # [1, batch, 1[32], head_dim]
         sin = ttnn.transpose(sin, 1, 2)  # [1, batch, 1[32], head_dim]
 
-        cos = cos[:, : self.batch_size, :, :]
-        sin = sin[:, : self.batch_size, :, :]
+        if self.batch_size % ttnn.TILE_SIZE != 0:
+            cos = cos[:, : self.batch_size, :, :]
+            sin = sin[:, : self.batch_size, :, :]
 
         grid = ttnn.num_cores_to_corerangeset(self.batch_size, self.core_grid, row_wise=True)
         mem_config = ttnn.create_sharded_memory_config(
diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding_llama.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding_llama.py
index 590a02ac9339..541dda21ae3c 100644
--- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding_llama.py
+++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding_llama.py
@@ -156,6 +156,16 @@ def run_test_rotary_embedding_llama(
 
     position_ids = torch.arange(batch) if mode == "decode" else slice(start_pos, start_pos + seq_len)
 
+    if mode == "decode":  # Pad position_ids up to a multiple of TILE_SIZE for decode mode
+        if fuse_qk:
+            # For fused_qk, repeat the position_ids for q and k
+            position_ids_padded = torch.cat([position_ids, position_ids])
+            pad_size = nearest_32(batch * 2) - batch * 2
+            position_ids_padded = torch.nn.functional.pad(position_ids_padded, (0, pad_size), "constant", 0)
+        else:
+            pad_size = nearest_32(batch) - batch
+            position_ids_padded = torch.nn.functional.pad(position_ids, (0, pad_size), "constant", 0)
+
     freqs_cis = freqs_cis[position_ids]
 
     # PyTorch Ground Truth output --------------------------------------------------------------------
@@ -173,18 +183,16 @@ def run_test_rotary_embedding_llama(
     tt_model = TtLlamaRotary(device, head_dim, mode, datatype, fuse_qk)
 
     if mode == "decode":
-        rope_setup_decode = TtLlamaRotarySetup(device, batch, head_dim, max_seq_len)
-        cos, sin = rope_setup_decode.get_rot_mats(position_ids)
-        tt_model.transformation_mat = rope_setup_decode.transformation_mat
-
         # For decode, TTNN expects inputs to be [1, batch, nh, dhead]
         inp = [x.transpose(1, 2) for x in inp]
 
         # inp: [seq_len, batch, n_heads, head_dim]
         if fuse_qk:
-            # For fused_qk, repeat the position_ids for q and k
-            position_ids = torch.concat([position_ids, position_ids])
-            cos, sin = rope_setup_decode.get_rot_mats(position_ids)
+            # Set up rope with 2 * batch size (for fused qk)
+            rope_setup_decode = TtLlamaRotarySetup(device, batch * 2, head_dim, max_seq_len)
+            tt_model.transformation_mat = rope_setup_decode.transformation_mat
+            cos, sin = rope_setup_decode.get_rot_mats(position_ids_padded)
+
             assert (
                 batch % 8 == 0 or batch == 1
             ), "Batch size must be a multiple of 8 or less than 8 for fused_qk rotary embedding"
@@ -219,18 +227,19 @@ def run_test_rotary_embedding_llama(
             input_mem_configs = [q_input_mem_config, k_input_mem_config]
 
         else:
-            cos, sin = rope_setup_decode.get_rot_mats(position_ids)
-            grid = (
-                ttnn.num_cores_to_corerangeset(batch, rope_setup_decode.core_grid, row_wise=True)
-                .bounding_box()
-                .grid_size()
-            )
+            # Set up rope with batch size
+            rope_setup_decode = TtLlamaRotarySetup(device, batch, head_dim, max_seq_len)
+            tt_model.transformation_mat = rope_setup_decode.transformation_mat
+            cos, sin = rope_setup_decode.get_rot_mats(position_ids_padded)
+
+            grid = ttnn.num_cores_to_corerangeset(batch, rope_setup_decode.core_grid, row_wise=True)
             input_mem_configs = [
                 ttnn.create_sharded_memory_config(
-                    shape=(1, batch, ttnn.TILE_SIZE, head_dim),
-                    core_grid=ttnn.CoreGrid(y=grid.y, x=grid.x),
+                    shape=(ttnn.TILE_SIZE, head_dim),
+                    core_grid=grid,
                     strategy=ttnn.ShardStrategy.HEIGHT,
                     orientation=ttnn.ShardOrientation.ROW_MAJOR,
+                    use_height_and_width_as_shard_shape=True,
                 )
                 for _ in range(len(inp))
             ]
@@ -448,10 +457,9 @@ def test_rotary_embedding_llama_with_program_cache(
     num_ops = 2  # 2 * rope
 
    if mode == "decode":
-        num_ops += 4  # embedding + transpose + pad + interleaved_to_sharded
+        num_ops += 3  # embedding + transpose + interleaved_to_sharded
 
-        # Extra ops to pad batch to tile size
         if batch % ttnn.TILE_SIZE != 0:
-            num_ops += 2  # pad + slice
+            num_ops += 1  # slice
 
     assert device.num_program_cache_entries() == num_ops
diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding_llama_fused_qk.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding_llama_fused_qk.py
index 579791f0eabd..893fe74baa58 100644
--- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding_llama_fused_qk.py
+++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_rotary_embedding_llama_fused_qk.py
@@ -132,9 +132,9 @@ def test_rotary_embedding_llama_fused_qk_with_program_cache(
         cache_tensors.append(test_tensor)
 
-    if batch == 32 or batch == 16:
-        num_ops = 4
-    else:
-        num_ops = 5  # embedding + fused_qk_rope + transpose + pad + interleaved_to_sharded
+    num_ops = 4  # embedding + fused_qk_rope + transpose + interleaved_to_sharded
+
+    if (batch * 2) % ttnn.TILE_SIZE != 0:
+        num_ops += 1  # slice
 
     assert device.num_program_cache_entries() == num_ops
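
Note: a minimal host-side sketch of the padding scheme this diff moves from get_rot_mats into the callers. nearest_32 is re-implemented locally and ttnn.TILE_SIZE is assumed to be 32, so the snippet runs with plain torch:

import torch

TILE_SIZE = 32  # assumption: mirrors ttnn.TILE_SIZE

def nearest_32(x):
    # Same rounding as models.utility_functions.nearest_32: round up to a multiple of 32
    return ((x + TILE_SIZE - 1) // TILE_SIZE) * TILE_SIZE

batch_size = 8
current_pos = torch.arange(batch_size)  # one decode position per user

# Pad the batch dim up to a tile boundary on host, as demo.py now does,
# instead of calling ttnn.pad on device inside get_rot_mats
pad_size = nearest_32(batch_size) - batch_size
current_pos_padded = torch.nn.functional.pad(current_pos, (0, pad_size), "constant", 0)

assert current_pos_padded.shape[0] % TILE_SIZE == 0  # 8 -> 32
# On device, cos/sin are sliced back to batch_size only when
# batch_size % TILE_SIZE != 0, which is why the expected program-cache
# op counts in the tests drop by one.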