From 4b048cf3542b057a957bcf62e3f00c544d2bf39c Mon Sep 17 00:00:00 2001
From: Mark O'Connor
Date: Wed, 20 Nov 2024 11:47:04 +0000
Subject: [PATCH] #0: Apply sharded residual only in decode mode

---
 models/demos/llama3/tt/llama_model.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/models/demos/llama3/tt/llama_model.py b/models/demos/llama3/tt/llama_model.py
index 5b1b5a49b3bd..04cf2c8d77be 100644
--- a/models/demos/llama3/tt/llama_model.py
+++ b/models/demos/llama3/tt/llama_model.py
@@ -84,7 +84,8 @@ def forward(
         get_last_token=-1,
     ):
         # No-op if callers already provide the right memory config
-        x = ttnn.to_memory_config(x, self.model_config["DECODE_RESIDUAL_MEMCFG"])
+        if mode == "decode":
+            x = ttnn.to_memory_config(x, self.model_config["DECODE_RESIDUAL_MEMCFG"])
 
         for layer in self.layers:
             x = layer(x, current_pos, rot_mat, transformation_mats, user_id, mode, page_table)
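
Note: the sketch below illustrates the pattern this patch introduces, gating the residual re-shard on decode mode so prefill inputs are left in their existing layout. Only `ttnn.to_memory_config` and the `"DECODE_RESIDUAL_MEMCFG"` config key are taken from the patch; the `TinyDecoder` class, its constructor arguments, and the layer call signature are hypothetical stand-ins for the real `llama_model.py` structure.

```python
# Minimal sketch, assuming a ttnn-style model wrapper (not the actual Llama3 model).
import ttnn


class TinyDecoder:
    def __init__(self, layers, model_config):
        self.layers = layers
        self.model_config = model_config  # assumed to contain "DECODE_RESIDUAL_MEMCFG"

    def forward(self, x, mode):
        # Converting to the sharded residual memory config is a no-op when callers
        # already provide it, but it should only happen in decode mode; prefill
        # activations keep whatever layout the caller passed in.
        if mode == "decode":
            x = ttnn.to_memory_config(x, self.model_config["DECODE_RESIDUAL_MEMCFG"])
        for layer in self.layers:
            x = layer(x, mode)
        return x
```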