Add padding to position ids to support rope with batch < 32 in trace mode. TODO: Debug inconsistent outputs of batch 1 vs batch 16/32
avoraTT authored and mtairum committed Nov 25, 2024
1 parent 50362ec commit a7e50dd
Showing 4 changed files with 49 additions and 39 deletions.
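The core of the change: ttnn works on 32-wide tiles, so instead of padding the rotary-index tensor on device inside get_rot_mats (the ttnn.pad path removed below), the position ids are now padded on the host to the next multiple of 32 before they reach the device. A minimal sketch of that host-side step, assuming nearest_32 rounds up to the next multiple of 32 (sample values are illustrative):

    import torch

    def nearest_32(x: int) -> int:
        # Round up to the next multiple of 32 (ttnn.TILE_SIZE).
        return ((x + 31) // 32) * 32

    batch_size = 16                                 # any batch < 32
    current_pos = torch.tensor([7] * batch_size)    # one decode position per user
    pad_size = nearest_32(batch_size) - batch_size  # 16 -> pad by 16
    current_pos_padded = torch.nn.functional.pad(current_pos, (0, pad_size), "constant", 0)
    print(current_pos_padded.shape)                 # torch.Size([32])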
14 changes: 10 additions & 4 deletions models/demos/llama3/demo/demo.py
@@ -15,6 +15,7 @@
 from pathlib import Path
 import hashlib
 
+from models.utility_functions import nearest_32
 from models.demos.llama3.tt.llama_common import (
     get_prefill_rot_mat,
     get_rot_transformation_mat,
@@ -407,6 +408,10 @@ def run_llama3_demo(user_input, single_layer, mesh_device, instruct_mode, is_ci_
 
     # Initial positions
     current_pos = torch.tensor([decoding_pos[b] for b in range(batch_size)])
+
+    pad_size = nearest_32(batch_size) - batch_size
+    current_pos_padded = torch.nn.functional.pad(current_pos, (0, pad_size), "constant", 0)
+
     current_pos_tensor = ttnn.from_torch(
         current_pos,
         device=mesh_device,
@@ -415,8 +420,8 @@ def run_llama3_demo(user_input, single_layer, mesh_device, instruct_mode, is_ci_
     )
 
     # Get cos/sin matrices for the current position of each user
-    rot_mats = rope_setup.get_rot_mats(current_pos)
-    rot_mat_idxs = rope_setup.get_rot_idxs(current_pos)
+    rot_mats = rope_setup.get_rot_mats(current_pos_padded)
+    rot_mat_idxs = rope_setup.get_rot_idxs(current_pos_padded)
 
     # Compile
     logger.info(f"Compiling model trace...")
@@ -490,7 +495,7 @@ def run_llama3_demo(user_input, single_layer, mesh_device, instruct_mode, is_ci_
     # Reset the current position and output token tensors for the real decode run
     ttnn.copy_host_to_device_tensor(current_pos_reset, current_pos_tensor)
     ttnn.copy_host_to_device_tensor(tt_out_tok_reset, tt_out_tok)
-    rot_mat_idxs_reset = rope_setup.get_rot_idxs(current_pos, on_host=True)
+    rot_mat_idxs_reset = rope_setup.get_rot_idxs(current_pos_padded, on_host=True)
     ttnn.copy_host_to_device_tensor(rot_mat_idxs_reset, rot_mat_idxs)
 
     profiler.end(f"capture_trace_{batch_idx}")
@@ -519,7 +524,8 @@ def run_llama3_demo(user_input, single_layer, mesh_device, instruct_mode, is_ci_
         # TODO: This is required for now since we cannot ttnn.plus_one(rot_mat_idxs) while it is uint32.
         # If this tensor is int32, it won't be supported by ttnn.embedding
         current_pos += 1
-        rot_mat_idxs_updated = rope_setup.get_rot_idxs(current_pos, on_host=True)
+        current_pos_padded += 1
+        rot_mat_idxs_updated = rope_setup.get_rot_idxs(current_pos_padded, on_host=True)
         ttnn.copy_host_to_device_tensor(rot_mat_idxs_updated, rot_mat_idxs)
 
         # Write to host
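The likely reason the padding moved to the host: trace capture replays a fixed sequence of device ops, and the old get_rot_mats only inserted its ttnn.pad when the batch was not tile-aligned, making the traced graph shape-dependent. With the torch-side pad, the index tensor always has shape [nearest_32(batch)]. Condensed, the per-step update now advances both copies in lockstep (identifiers as in the diff; loop framing omitted):

    # One decode iteration (sketch): the real positions drive the model,
    # while the padded copy regenerates the on-device rotary indices.
    current_pos += 1
    current_pos_padded += 1
    rot_mat_idxs_updated = rope_setup.get_rot_idxs(current_pos_padded, on_host=True)
    ttnn.copy_host_to_device_tensor(rot_mat_idxs_updated, rot_mat_idxs)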
22 changes: 9 additions & 13 deletions models/demos/llama3/tt/llama_rope.py
@@ -94,8 +94,8 @@ def get_rot_idxs(self, position_idxs, on_host=False):
         assert len(position_idxs.shape) == 1, "position idxs must be a [batch] tensor"
 
         batch = position_idxs.shape[0]
-        position_idxs = position_idxs.reshape(1, 1, 1, batch)  # [1, 1, 1, batch]
-        assert position_idxs.shape == (1, 1, 1, batch), "position idxs must be a [1, batch] tensor"
+        position_idxs = position_idxs.reshape(1, batch)  # [1, batch]
+        assert position_idxs.shape == (1, batch), "position idxs must be a [1, batch] tensor"
         assert torch.min(position_idxs) >= 0, "position idxs must be non-negative"
 
         if on_host:
@@ -125,20 +125,15 @@ def get_rot_mats(self, position_idxs, return_rot_idxs=False):
             rot_idxs = self.get_rot_idxs(position_idxs)
         else:
             rot_idxs = position_idxs
-            assert rot_idxs.shape == [1, 1, 1, self.batch_size], "rot_idxs must be a [1, 1, 1, batch] tensor"
+            assert len(rot_idxs.shape) == 2 and rot_idxs.shape == [
+                1,
+                rot_idxs.shape[1],
+            ], "rot_idxs must be a [1, batch] tensor"
 
         # Send the idxs to device
         if rot_idxs.device != device:
             rot_idxs = ttnn.to_device(rot_idxs, device, memory_config=ttnn.DRAM_MEMORY_CONFIG)
 
-        batch = rot_idxs.shape[3]
-
-        # Pad the batch dimension to be a multiple of TILE_SIZE
-        if batch % ttnn.TILE_SIZE != 0:
-            pad_size = ttnn.TILE_SIZE - (batch % ttnn.TILE_SIZE)
-            rot_idxs = ttnn.pad(rot_idxs, [1, 1, 1, batch + pad_size], [0, 0, 0, batch], 0.0)
-            batch = rot_idxs.shape[3]
-
         embedding_layout = ttnn.TILE_LAYOUT
         cos = ttnn.embedding(rot_idxs, self.cos_matrix, layout=embedding_layout)  # [1, batch, head_dim]
         sin = ttnn.embedding(rot_idxs, self.sin_matrix, layout=embedding_layout)  # [1, batch, head_dim]
@@ -149,8 +144,9 @@ def get_rot_mats(self, position_idxs, return_rot_idxs=False):
         cos = ttnn.transpose(cos, 1, 2)  # [1, batch, 1[32], head_dim]
         sin = ttnn.transpose(sin, 1, 2)  # [1, batch, 1[32], head_dim]
 
-        cos = cos[:, : self.batch_size, :, :]
-        sin = sin[:, : self.batch_size, :, :]
+        if self.batch_size % ttnn.TILE_SIZE != 0:
+            cos = cos[:, : self.batch_size, :, :]
+            sin = sin[:, : self.batch_size, :, :]
 
         grid = ttnn.num_cores_to_corerangeset(self.batch_size, self.core_grid, row_wise=True)
         mem_config = ttnn.create_sharded_memory_config(
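After this change get_rot_mats expects indices that are already padded to [1, nearest_32(batch)]: it gathers the cos/sin rows with ttnn.embedding and slices back to the real batch only when padding was applied. The zero pad value is harmless: it merely gathers row 0 of the cos/sin tables, and those rows are discarded by the slice. A torch-only sketch of the equivalent dataflow (4-D reshapes and ttnn memory configs omitted; sizes illustrative):

    import torch

    TILE_SIZE, head_dim, max_seq_len, batch_size = 32, 128, 8192, 16

    cos_matrix = torch.randn(max_seq_len, head_dim)      # precomputed per-position table
    rot_idxs = torch.zeros(1, 32, dtype=torch.long)      # padded [1, nearest_32(batch)]
    rot_idxs[0, :batch_size] = torch.arange(batch_size)  # real positions; pads stay 0

    cos = cos_matrix[rot_idxs]                           # embedding-style gather: [1, 32, head_dim]
    if batch_size % TILE_SIZE != 0:
        cos = cos[:, :batch_size, :]                     # drop pad rows: [1, 16, head_dim]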
44 changes: 26 additions & 18 deletions (test file; path not captured)
@@ -156,6 +156,16 @@ def run_test_rotary_embedding_llama(
 
     position_ids = torch.arange(batch) if mode == "decode" else slice(start_pos, start_pos + seq_len)
 
+    if mode == "decode":  # Pad position_ids to batch size for decode mode
+        if fuse_qk:
+            # For fused_qk, repeat the position_ids for q and k
+            position_ids_padded = torch.cat([position_ids, position_ids])
+            pad_size = nearest_32(batch * 2) - batch * 2
+            position_ids_padded = torch.nn.functional.pad(position_ids_padded, (0, pad_size), "constant", 0)
+        else:
+            pad_size = nearest_32(batch) - batch
+            position_ids_padded = torch.nn.functional.pad(position_ids, (0, pad_size), "constant", 0)
+
     freqs_cis = freqs_cis[position_ids]
 
     # PyTorch Ground Truth output --------------------------------------------------------------------
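For the fused-QK path the q and k position ids are concatenated before padding, so the padded length is nearest_32(2 * batch). With batch = 8, for instance (values illustrative):

    import torch

    position_ids = torch.arange(8)                    # batch = 8
    qk_ids = torch.cat([position_ids, position_ids])  # q and k share positions: len 16
    pad_size = 32 - qk_ids.shape[0]                   # nearest_32(16) - 16 = 16
    qk_ids_padded = torch.nn.functional.pad(qk_ids, (0, pad_size), "constant", 0)
    assert qk_ids_padded.shape[0] == 32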
@@ -173,18 +183,16 @@ def run_test_rotary_embedding_llama(
     tt_model = TtLlamaRotary(device, head_dim, mode, datatype, fuse_qk)
 
     if mode == "decode":
-        rope_setup_decode = TtLlamaRotarySetup(device, batch, head_dim, max_seq_len)
-        cos, sin = rope_setup_decode.get_rot_mats(position_ids)
-        tt_model.transformation_mat = rope_setup_decode.transformation_mat
-
         # For decode, TTNN expects inputs to be [1, batch, nh, dhead]
         inp = [x.transpose(1, 2) for x in inp]
+        # inp: [seq_len, batch, n_heads, head_dim]
 
         if fuse_qk:
-            # For fused_qk, repeat the position_ids for q and k
-            position_ids = torch.concat([position_ids, position_ids])
-            cos, sin = rope_setup_decode.get_rot_mats(position_ids)
+            # Set up rope with 2 * batch size (for fused qk)
+            rope_setup_decode = TtLlamaRotarySetup(device, batch * 2, head_dim, max_seq_len)
+            tt_model.transformation_mat = rope_setup_decode.transformation_mat
+            cos, sin = rope_setup_decode.get_rot_mats(position_ids_padded)
 
             assert (
                 batch % 8 == 0 or batch == 1
             ), "Batch size must be a multiple of 8 or less than 8 for fused_qk rotary embedding"
@@ -219,18 +227,19 @@ def run_test_rotary_embedding_llama(
             input_mem_configs = [q_input_mem_config, k_input_mem_config]
 
         else:
-            cos, sin = rope_setup_decode.get_rot_mats(position_ids)
-            grid = (
-                ttnn.num_cores_to_corerangeset(batch, rope_setup_decode.core_grid, row_wise=True)
-                .bounding_box()
-                .grid_size()
-            )
+            # Set up rope with batch size
+            rope_setup_decode = TtLlamaRotarySetup(device, batch, head_dim, max_seq_len)
+            tt_model.transformation_mat = rope_setup_decode.transformation_mat
+            cos, sin = rope_setup_decode.get_rot_mats(position_ids_padded)
+
+            grid = ttnn.num_cores_to_corerangeset(batch, rope_setup_decode.core_grid, row_wise=True)
             input_mem_configs = [
                 ttnn.create_sharded_memory_config(
-                    shape=(1, batch, ttnn.TILE_SIZE, head_dim),
-                    core_grid=ttnn.CoreGrid(y=grid.y, x=grid.x),
+                    shape=(ttnn.TILE_SIZE, head_dim),
+                    core_grid=grid,
                     strategy=ttnn.ShardStrategy.HEIGHT,
                     orientation=ttnn.ShardOrientation.ROW_MAJOR,
+                    use_height_and_width_as_shard_shape=True,
                 )
                 for _ in range(len(inp))
             ]
@@ -448,10 +457,9 @@ def test_rotary_embedding_llama_with_program_cache(
 
     num_ops = 2  # 2 * rope
     if mode == "decode":
-        num_ops += 4  # embedding + transpose + pad + interleaved_to_sharded
+        num_ops += 3  # embedding + transpose + interleaved_to_sharded
 
-        # Extra ops to pad batch to tile size
         if batch % ttnn.TILE_SIZE != 0:
-            num_ops += 2  # pad + slice
+            num_ops += 1  # slice
 
     assert device.num_program_cache_entries() == num_ops
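Sanity-checking the new counts: removing the on-device pad drops one cached program, and the trailing slice only exists when the batch is not tile-aligned. For batch = 32 in decode mode that is 2 (rope) + 3 (embedding, transpose, interleaved_to_sharded) = 5 cache entries; for batch = 16 the slice adds one more, giving 6.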
8 changes: 4 additions & 4 deletions (fused-QK test file; path not captured)
@@ -132,9 +132,9 @@ def test_rotary_embedding_llama_fused_qk_with_program_cache(
 
         cache_tensors.append(test_tensor)
 
-    if batch == 32 or batch == 16:
-        num_ops = 4
-    else:
-        num_ops = 5  # embedding + fused_qk_rope + transpose + pad + interleaved_to_sharded
+    num_ops = 4  # embedding + fused_qk_rope + transpose + interleaved_to_sharded
+
+    if (batch * 2) % ttnn.TILE_SIZE != 0:
+        num_ops += 1  # slice
 
     assert device.num_program_cache_entries() == num_ops
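The same arithmetic holds for the fused-QK variant: the base is 4 (embedding, fused_qk_rope, transpose, interleaved_to_sharded), and because q and k double the effective batch, the slice is needed only when 2 * batch is not a multiple of 32. Batch 16 or 32 therefore stays at 4 entries, matching the old hard-coded branch, while batch 8 (16 ids, padded to 32) adds the slice for 5.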
