#0: fixed pcc issue for sharded all_gather after fused qkv (#14138)
kpaigwar authored Oct 23, 2024
1 parent 25ccab3 · commit ec805b3
Showing 7 changed files with 22 additions and 19 deletions.

models/demos/tg/llama3_70b/tests/test_llama_decoder_galaxy.py (5 additions, 2 deletions)
@@ -113,9 +113,12 @@ def tt_llama_decoder_prepare_inputs(llama_decoder_model, x, start_pos, mode):
     assert seq_len == 1, "Only supporting decode mode"
     x = x.transpose(0, 1).unsqueeze(1) # [seq_len, 1, batch, hidden_dim]
 
+    ACT_CORE_GRID_Y = llama_decoder_model.model_config["DECODE_ACT_CORE_GRID"][0]
+    ACT_CORE_GRID_X = llama_decoder_model.model_config["DECODE_ACT_CORE_GRID"][1]
+    ACT_CORE_GRID_SIZE = ACT_CORE_GRID_Y * ACT_CORE_GRID_X
     ACT_MEMCFG = ttnn.create_sharded_memory_config(
-        shape=(x.shape[2], x.shape[3] // 8 // llama_decoder_model.cluster_shape[0]),
-        core_grid=ttnn.CoreGrid(y=1, x=8),
+        shape=(x.shape[2], x.shape[3] // ACT_CORE_GRID_SIZE // llama_decoder_model.cluster_shape[0]),
+        core_grid=ttnn.CoreGrid(y=ACT_CORE_GRID_Y, x=ACT_CORE_GRID_X),
         strategy=ttnn.ShardStrategy.WIDTH,
         orientation=ttnn.ShardOrientation.ROW_MAJOR,
         use_height_and_width_as_shard_shape=True,
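
As a quick sanity check on the new grid (my own sketch, not part of the commit): the per-core shard width shrinks by 4x when the activation grid grows from 1x8 to the DECODE_ACT_CORE_GRID of 4x8. The hidden dimension of 8192 and cluster_shape[0] of 4 below are assumed, illustrative values for Llama 3 70B on a Galaxy mesh.

hidden_dim = 8192        # assumed Llama 3 70B hidden size (illustrative)
cluster_frac = 4         # assumed cluster_shape[0] (illustrative)

old_cores = 1 * 8        # previous ttnn.CoreGrid(y=1, x=8)
new_cores = 4 * 8        # DECODE_ACT_CORE_GRID = (4, 8) added in model_config.py

print(hidden_dim // old_cores // cluster_frac)   # 256 columns per core before
print(hidden_dim // new_cores // cluster_frac)   # 64 columns per core after

Both widths remain multiples of the 32-element tile width, which is why the wider grid is still a legal width-sharding choice here.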

models/demos/tg/llama3_70b/tt/llama_attention_galaxy.py (7 additions, 8 deletions)
@@ -253,18 +253,17 @@ def attn_qkv(
             self.qkv,
             program_config=self.attention_config["FUSED_QKV_MM_PROGCFG"],
             dtype=ttnn.bfloat16,
-            memory_config=ttnn.DRAM_MEMORY_CONFIG,
+            memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
             compute_kernel_config=self.attention_config["COMPUTE_KERNEL_QKV"],
         )
         xs.deallocate(True)
 
-        # TODO: Use sharded all_reduce when PCC issue is fixed in this particular configuration
-        # fused_query_key_value = tt_sharded_all_reduce(
-        #     fused_query_key_value, self.mesh_device, cluster_axis=1, num_links=2, memory_config=self.attention_config["QKV_OUT_GATHERED_MEMCFG"](self.cluster_shape[0])
-        # )
-
-        fused_query_key_value = tt_all_reduce(
-            fused_query_key_value, self.mesh_device, cluster_axis=1, num_links=2, memory_config=ttnn.DRAM_MEMORY_CONFIG
+        fused_query_key_value = tt_sharded_all_reduce(
+            fused_query_key_value,
+            self.mesh_device,
+            cluster_axis=1,
+            num_links=2,
+            memory_config=self.attention_config["QKV_OUT_GATHERED_MEMCFG"](self.cluster_shape[0]),
         )
 
         # TODO: Slice the fused_query_key_value tensor get batch=8
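
For intuition (my sketch, not the repository's code): tt_sharded_all_reduce sums the per-chip partial results of the fused QKV matmul across cluster_axis=1, and the helper in llama_common.py below realizes it as an all-gather followed by fast_reduce_nc. Assuming the gather stacks the partial tensors along a leading dim and fast_reduce_nc sums over that dim, the underlying math is just:

import numpy as np

# Toy shapes; each entry stands for one chip's partial fused-QKV activation.
num_chips = 4                                   # chips along cluster_axis=1 (assumed)
partials = [np.random.rand(1, 32, 128) for _ in range(num_chips)]

gathered = np.stack(partials, axis=0)           # what the all-gather leaves on every chip
reduced = gathered.sum(axis=0)                  # what fast_reduce_nc(dims=[0]) is assumed to compute

# Every chip ends up holding the same fully reduced activations.
assert np.allclose(reduced, sum(partials))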

models/demos/tg/llama3_70b/tt/llama_common.py (1 addition, 1 deletion)
@@ -82,7 +82,7 @@ def tt_sharded_all_reduce(input_tensor, mesh_device, cluster_axis, dim=0, num_li
     # Fast_reduce_nc does not support sharded memory configuration, convert to interleaved
     gathered_tensor = ttnn.to_memory_config(gathered_tensor, ttnn.L1_MEMORY_CONFIG)
     reduced_tensors = ttnn.experimental.fast_reduce_nc(
-        gathered_tensor, dims=[dim], output=None, compute_kernel_config=None
+        gathered_tensor, dims=[dim], output=None, compute_kernel_config=None, memory_config=ttnn.L1_MEMORY_CONFIG
     )
     return reduced_tensors

models/demos/tg/llama3_70b/tt/llama_decoder_galaxy.py (3 additions, 4 deletions)
@@ -161,13 +161,11 @@ def decode_forward(
             ln_sharded_stats_memcfg=self.decoder_config["LN_SHARDED_STATS_MEMCFG"],
         )
 
-        attn_norm_out = ttnn.to_memory_config(attn_norm_out, memory_config=self.decoder_config["ATTN_ACT_MEMCFG"])
         attn_outs = self.attention(attn_norm_out, rot_mats, start_pos, attn_masks, mode="decode")
-        attn_outs = ttnn.to_memory_config(attn_outs, memory_config=self.decoder_config["MLP_ACT_MEMCFG"])
+        attn_outs = ttnn.to_memory_config(attn_outs, memory_config=self.decoder_config["ATTN_ACT_MEMCFG"])
 
-        output = xs
         output = ttnn.add(
-            output,
+            xs,
             attn_outs,
             memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
         )
@@ -186,6 +184,7 @@ def decode_forward(
         ffn_norm_out = ttnn.to_memory_config(ffn_norm_out, memory_config=self.decoder_config["MLP_ACT_MEMCFG"])
         ffn_out = self.mlp(ffn_norm_out, mode="decode")
         ffn_norm_out.deallocate(True)
+        ffn_out = ttnn.to_memory_config(ffn_out, memory_config=self.decoder_config["ATTN_ACT_MEMCFG"])
         ### residual add
         output = ttnn.add(
             output,

models/demos/tg/llama3_70b/tt/llama_mlp_galaxy.py (0 additions, 2 deletions)
@@ -184,8 +184,6 @@ def decode_forward(self, x: List[ttnn.Tensor]) -> List[ttnn.Tensor]:
             memory_config=self.mlp_config["FF2_OUT_GATHERED_MEMCFG"],
         )
 
-        hidden_states = ttnn.to_memory_config(hidden_states, self.mlp_config["FF1_ACT_MEMCFG"])
-
         return hidden_states
 
     def prefill_forward(self, x: List[ttnn.Tensor]) -> List[ttnn.Tensor]:

models/demos/tg/llama3_70b/tt/llama_model_galaxy.py (5 additions, 2 deletions)
@@ -182,9 +182,12 @@ def prepare_inputs(self, inp_ids, start_pos, valid_seq_len=None, attn_mask=None,
         assert seq_len == 1, "Decode mode only supports seq_len=1"
         assert xs.shape == (seq_len, 1, batch, self.hidden_size // self.cluster_shape[0])
 
+        ACT_CORE_GRID_Y = self.model_config["DECODE_ACT_CORE_GRID"][0]
+        ACT_CORE_GRID_X = self.model_config["DECODE_ACT_CORE_GRID"][1]
+        ACT_CORE_GRID_SIZE = ACT_CORE_GRID_Y * ACT_CORE_GRID_X
         ACT_MEMCFG = ttnn.create_sharded_memory_config(
-            shape=(xs.shape[2], xs.shape[3] // 8),
-            core_grid=ttnn.CoreGrid(y=1, x=8),
+            shape=(xs.shape[2], xs.shape[3] // ACT_CORE_GRID_SIZE),
+            core_grid=ttnn.CoreGrid(y=ACT_CORE_GRID_Y, x=ACT_CORE_GRID_X),
             strategy=ttnn.ShardStrategy.WIDTH,
             orientation=ttnn.ShardOrientation.ROW_MAJOR,
             use_height_and_width_as_shard_shape=True,
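
To tie the pieces together, here is a hedged sketch of how the new DECODE_ACT_CORE_GRID entry is consumed. Only the config key, the (4, 8) grid value, and the ttnn calls are taken from this commit; the batch of 32 and the per-chip hidden width of 2048 are illustrative assumptions, not values stated in the diff.

import ttnn

DECODE_ACT_CORE_GRID = (4, 8)                    # value added in model_config.py
grid_y, grid_x = DECODE_ACT_CORE_GRID
grid_size = grid_y * grid_x                      # 32 cores

batch = 32                                       # assumed decode batch (illustrative)
per_chip_hidden = 2048                           # assumed hidden_size // cluster_shape[0]

act_memcfg = ttnn.create_sharded_memory_config(
    shape=(batch, per_chip_hidden // grid_size), # (32, 64) shard per core
    core_grid=ttnn.CoreGrid(y=grid_y, x=grid_x),
    strategy=ttnn.ShardStrategy.WIDTH,
    orientation=ttnn.ShardOrientation.ROW_MAJOR,
    use_height_and_width_as_shard_shape=True,
)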

models/demos/tg/llama3_70b/tt/model_config.py (1 addition, 0 deletions)
@@ -71,6 +71,7 @@ def get_model_config(llama_version="llama3-tg", max_batch_size=32, max_context_l
            fp32_dest_acc_en=True,
            packer_l1_acc=True,
        ),
+       "DECODE_ACT_CORE_GRID": (4, 8),
    }
 
    if llama_version == "llama3" or llama_version == "llama3-tg":
