
Commit 6ff4c6e
Revert changes to Falcon7b matmul configs to fix CI tests after default matmul configs were modified (#15439)
skhorasganiTT authored Nov 26, 2024
1 parent 94dace1 commit 6ff4c6e
Showing 5 changed files with 63 additions and 4 deletions.
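In short, each affected matmul goes back to an explicitly constructed compute-kernel config instead of inheriting ttnn's recently changed default matmul config. A minimal sketch of the restored pattern (assuming a Wormhole B0 device; a and b are placeholder ttnn tensors, not names from the diff):

    import ttnn

    # Sketch only, not part of the diff: pin the matmul to an explicit kernel
    # config so the op no longer depends on ttnn's default matmul configuration.
    hifi2_config = ttnn.WormholeComputeKernelConfig(
        math_fidelity=ttnn.MathFidelity.HiFi2,
        math_approx_mode=False,
        fp32_dest_acc_en=False,
        packer_l1_acc=False,
    )
    out = ttnn.matmul(a, b, compute_kernel_config=hifi2_config)  # a, b: assumed ttnn tensors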
models/demos/falcon7b_common/tt/falcon_attention.py (6 additions, 0 deletions)
@@ -217,6 +217,7 @@ def forward(
             memory_config=self.model_config["FUSED_QKV_MM_OUTPUT_MEMCFG"],
             dtype=self.model_config["FUSED_QKV_MM_OUTPUT_DTYPE"],
             core_grid=get_falcon_default_core_grid(hidden_states.device()),
+            compute_kernel_config=self.model_config["DEFAULT_LoFi_KERNEL_CONFIG"],
         )
 
         ###########
@@ -261,6 +262,7 @@ def forward(
             query_layer,
             key_layer_transposed,
             memory_config=self.model_config["PRE_SOFTMAX_MM_OUTPUT_MEMCFG"],
+            compute_kernel_config=self.model_config["DEFAULT_HiFi2_KERNEL_CONFIG"],
         )
         query_layer.deallocate()
         key_layer_transposed.deallocate()
@@ -295,6 +297,7 @@ def forward(
             attn_weights,
             value_layer,
             memory_config=self.model_config["POST_SOFTMAX_MM_OUTPUT_MEMCFG"],
+            compute_kernel_config=self.model_config["DEFAULT_HiFi2_KERNEL_CONFIG"],
         )
         attn_weights.deallocate()
         value_layer.deallocate()
@@ -313,6 +316,7 @@ def forward(
             memory_config=self.model_config["SELFOUT_MM_OUTPUT_MEMCFG"],
             dtype=self.model_config["SELFOUT_MM_OUTPUT_DTYPE"],
             core_grid=get_falcon_default_core_grid(attn_output.device()),
+            compute_kernel_config=self.model_config["DEFAULT_LoFi_KERNEL_CONFIG"],
         )
 
         return attn_output, layer_present
@@ -589,6 +593,7 @@ def forward(
             memory_config=self.model_config["FUSED_QKV_MM_OUTPUT_MEMCFG"],
             dtype=self.model_config["FUSED_QKV_MM_OUTPUT_DTYPE"],
             core_grid=get_falcon_default_core_grid(hidden_states.device()),
+            compute_kernel_config=self.model_config["DEFAULT_LoFi_KERNEL_CONFIG"],
         )
 
         ###########
@@ -845,6 +850,7 @@ def forward(
             memory_config=self.model_config["SELFOUT_MM_OUTPUT_MEMCFG"],
             dtype=self.model_config["SELFOUT_MM_OUTPUT_DTYPE"],
             core_grid=get_falcon_default_core_grid(attn_output.device()),
+            compute_kernel_config=self.model_config["DEFAULT_LoFi_KERNEL_CONFIG"],
         )
 
         return attn_output, layer_present
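The fidelity split above is consistent across both the decode and prefill forward paths; roughly (illustrative summary only, not code from the diff):

    # Illustrative mapping of the attention matmuls to the kernel configs added above:
    FALCON_ATTENTION_FIDELITY = {
        "fused_qkv_matmul": "DEFAULT_LoFi_KERNEL_CONFIG",      # weight projection
        "pre_softmax_matmul": "DEFAULT_HiFi2_KERNEL_CONFIG",   # Q @ K^T scores
        "post_softmax_matmul": "DEFAULT_HiFi2_KERNEL_CONFIG",  # scores @ V
        "selfout_matmul": "DEFAULT_LoFi_KERNEL_CONFIG",        # output projection
    }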
models/demos/falcon7b_common/tt/falcon_causallm.py (12 additions, 2 deletions)
@@ -9,7 +9,11 @@
 from ttnn import ReplicateTensorToMesh
 from models.demos.falcon7b_common.tt.falcon_lm_head import falcon_lm_head_matmul_2d
 from models.demos.falcon7b_common.tt.falcon_model import TtFalconModelShared
-from models.demos.falcon7b_common.tt.model_utils import get_falcon_default_core_grid, get_weights_cached
+from models.demos.falcon7b_common.tt.model_utils import (
+    get_falcon_default_core_grid,
+    get_weights_cached,
+    get_default_hifi2_kernel_config,
+)
 from models.demos.falcon7b_common.tests.test_utils import tt_from_torch
 from models.utility_functions import (
     is_grayskull,
@@ -27,7 +31,13 @@ def falcon_lm_head_matmul(
     seq_len = input_tensor_a.shape.with_tile_padding()[2]
     if seq_len > 512:
         # TODO: Review if this path is used? If not, we can delete
-        return ttnn.matmul(input_tensor_a, input_tensor_b, memory_config=output_mem_config, dtype=output_dtype)
+        return ttnn.matmul(
+            input_tensor_a,
+            input_tensor_b,
+            memory_config=output_mem_config,
+            dtype=output_dtype,
+            compute_kernel_config=get_default_hifi2_kernel_config(),
+        )
 
     if is_grayskull():
         compute_kernel_config = ttnn.GrayskullComputeKernelConfig(
models/demos/falcon7b_common/tt/falcon_mlp.py (12 additions, 2 deletions)
@@ -5,7 +5,11 @@
 import torch
 import ttnn
 from ttnn import ReplicateTensorToMesh
-from models.demos.falcon7b_common.tt.model_utils import get_falcon_default_core_grid, get_weights_cached
+from models.demos.falcon7b_common.tt.model_utils import (
+    get_falcon_default_core_grid,
+    get_weights_cached,
+    get_default_hifi2_kernel_config,
+)
 from models.demos.falcon7b_common.tests.test_utils import tt_from_torch
 from torch import nn
 from models.utility_functions import (
@@ -58,7 +62,13 @@ def falcon_dense_h_to_4h_matmul(
     if seq_len > 1024:
         # TODO: Review if this path is used? If not, we can delete
         assert fused_activation == None
-        return ttnn.matmul(input_tensor_a, input_tensor_b, memory_config=output_mem_config, dtype=output_dtype)
+        return ttnn.matmul(
+            input_tensor_a,
+            input_tensor_b,
+            memory_config=output_mem_config,
+            dtype=output_dtype,
+            compute_kernel_config=get_default_hifi2_kernel_config(),
+        )
 
     if is_grayskull():
         compute_kernel_config = ttnn.GrayskullComputeKernelConfig(
models/demos/falcon7b_common/tt/model_config.py (17 additions, 0 deletions)
@@ -8,6 +8,7 @@
 from pathlib import Path
 from transformers import FalconConfig
 from models.utility_functions import is_grayskull, is_wormhole_b0
+from models.demos.falcon7b_common.tt.model_utils import get_default_hifi2_kernel_config
 
 OP_KEYS = (
     # Inputs
@@ -290,6 +291,22 @@ def get_model_config(model_config_str, prefill_seq_len=0, decode_batch_size=32):
     model_config["PRE_SOFTMAX_MM_COMPUTE_KERNEL_CONFIG"] = gs_compute_kernel_config
     model_config["POST_SOFTMAX_MM_COMPUTE_KERNEL_CONFIG"] = gs_compute_kernel_config
 
+    if is_wormhole_b0():
+        default_lofi_kernel_config = ttnn.WormholeComputeKernelConfig(
+            math_fidelity=ttnn.MathFidelity.LoFi,
+            math_approx_mode=False,
+            fp32_dest_acc_en=False,
+            packer_l1_acc=False,
+        )
+    else:
+        default_lofi_kernel_config = ttnn.GrayskullComputeKernelConfig(
+            math_fidelity=ttnn.MathFidelity.LoFi,
+            math_approx_mode=True,
+        )
+    model_config["DEFAULT_LoFi_KERNEL_CONFIG"] = default_lofi_kernel_config
+
+    model_config["DEFAULT_HiFi2_KERNEL_CONFIG"] = get_default_hifi2_kernel_config()
+
     # uncomment if need to see all the configs
     # logger.debug(f"Falcon model config: \n{pretty_print_model_config(model_config)}")
     set_prefill_config(model_config, prefill_seq_len, DRAM_MEMCFG)
models/demos/falcon7b_common/tt/model_utils.py (16 additions, 0 deletions)
@@ -70,6 +70,22 @@ def get_falcon_default_core_grid(device):
     return ttnn.CoreGrid(y=grid_size.y, x=grid_size.x)
 
 
+def get_default_hifi2_kernel_config():
+    if is_wormhole_b0():
+        hifi2_kernel_config = ttnn.WormholeComputeKernelConfig(
+            math_fidelity=ttnn.MathFidelity.HiFi2,
+            math_approx_mode=False,
+            fp32_dest_acc_en=False,
+            packer_l1_acc=False,
+        )
+    else:
+        hifi2_kernel_config = ttnn.GrayskullComputeKernelConfig(
+            math_fidelity=ttnn.MathFidelity.HiFi2,
+            math_approx_mode=True,
+        )
+    return hifi2_kernel_config
+
+
 def layernorm(ln_input, ln_eps, ln_gamma, ln_betta, model_config):
     h_dim = ln_input.shape.with_tile_padding()[-2]  # corresponds to batch size (decode) or seq_len (prefill)
     if h_dim in [32, 128, 256, 1024, 2048]:
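For reference, a hedged usage sketch of the new helper, mirroring the long-sequence fallback paths in falcon_causallm.py and falcon_mlp.py (tensor names are placeholders, not from the diff):

    import ttnn
    from models.demos.falcon7b_common.tt.model_utils import get_default_hifi2_kernel_config

    # The helper selects the Wormhole or Grayskull HiFi2 config for the current
    # architecture; callers pass it straight through to ttnn.matmul.
    result = ttnn.matmul(
        input_tensor_a,  # placeholder ttnn tensors
        input_tensor_b,
        compute_kernel_config=get_default_hifi2_kernel_config(),
    )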
