Skip to content

Commit

Permalink
#0: Apply sharded residual only in decode mode
Browse files Browse the repository at this point in the history
  • Loading branch information
yieldthought committed Nov 20, 2024
1 parent 1a1e4ae commit 4b048cf
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion models/demos/llama3/tt/llama_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ def forward(
get_last_token=-1,
):
# No-op if callers already provide the right memory config
x = ttnn.to_memory_config(x, self.model_config["DECODE_RESIDUAL_MEMCFG"])
if mode == "decode":
x = ttnn.to_memory_config(x, self.model_config["DECODE_RESIDUAL_MEMCFG"])

for layer in self.layers:
x = layer(x, current_pos, rot_mat, transformation_mats, user_id, mode, page_table)
Expand Down

0 comments on commit 4b048cf

Please sign in to comment.