Add testing for program caching and relu activation.
avoraTT committed Nov 21, 2024 · 1 parent ca82556 · commit 17cfe25
Showing 1 changed file with 22 additions and 10 deletions.
@@ -85,8 +85,8 @@ def get_physical_to_logical_core_mapping(device):
@pytest.mark.parametrize(
"B, M, K, N, in0_dtype, in1_dtype, fidelity, packer_l1_acc, fp32_acc_mode, grid",
[
# 32, 2304, 3840 (PREFETCHER), only works on TG
(1, 32, 2304, 3840, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.MathFidelity.LoFi, True, True, PREFETCHER_GRID),
# # 32, 2304, 3840 (PREFETCHER), only works on TG
# (1, 32, 2304, 3840, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.MathFidelity.LoFi, True, True, PREFETCHER_GRID),
# 32, 2304, 3840
(1, 32, 2304, 3840, ttnn.bfloat16, ttnn.bfloat4_b, ttnn.MathFidelity.LoFi, True, True, (8, 3)),
# 32, 2304, 3840
@@ -116,12 +116,17 @@ def get_physical_to_logical_core_mapping(device):
[
None,
ttnn.UnaryOpType.SILU,
ttnn.UnaryOpType.RELU,
],
)
@pytest.mark.parametrize(
"use_arbitrary_cores",
[False, True],
)
@pytest.mark.parametrize(
"num_iters",
[1, 3],
)
def test_multi_core_matmul_1d_wh(
device,
in0_dtype,
@@ -137,6 +142,8 @@ def test_multi_core_matmul_1d_wh(
activation,
grid,
use_arbitrary_cores,
num_iters,
use_program_cache,
function_level_defaults,
):
assert not has_bias, "Bias not supported for gather_in0 mode."
@@ -276,19 +283,24 @@ def test_multi_core_matmul_1d_wh(
dst_full_sync_en=True,
)

output_t = ttnn.matmul(
in0_t,
in1_t,
program_config=program_config,
memory_config=output_sharded_mem_config,
compute_kernel_config=compute_kernel_config,
)
for _ in range(num_iters):
output_t = ttnn.matmul(
in0_t,
in1_t,
program_config=program_config,
memory_config=output_sharded_mem_config,
compute_kernel_config=compute_kernel_config,
)
tt_out = ttnn.to_torch(output_t)

pt_out = in0 @ in1
if activation:
pt_out = torch.nn.SiLU()(pt_out)
act_fnc = torch.nn.functional.silu if activation == ttnn.UnaryOpType.SILU else torch.nn.functional.relu
pt_out = act_fnc(pt_out)

passing, output = comp_pcc(pt_out, tt_out)
logger.info(output)
assert passing

# Check program cache
assert device.num_program_cache_entries() == 1 # Only 1 op
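
Net effect of the change, read end to end: the matmul is launched num_iters times under the use_program_cache fixture, the torch reference applies whichever fused activation (SiLU or ReLU) was parametrized, and the test closes by asserting that the repeated launches produced a single program-cache entry. Below is a condensed sketch of the resulting test body, using only the calls visible in this diff and assuming the tensor/config setup (in0, in1, in0_t, in1_t, program_config, output_sharded_mem_config, compute_kernel_config) from the unchanged part of the test.

# Sketch only: in0/in1 (torch tensors), in0_t/in1_t (ttnn tensors) and the
# program/memory/compute configs are assumed to come from the unchanged setup.
for _ in range(num_iters):
    output_t = ttnn.matmul(
        in0_t,
        in1_t,
        program_config=program_config,
        memory_config=output_sharded_mem_config,
        compute_kernel_config=compute_kernel_config,
    )
tt_out = ttnn.to_torch(output_t)

# Torch reference with the matching fused activation, if one was requested.
pt_out = in0 @ in1
if activation == ttnn.UnaryOpType.SILU:
    pt_out = torch.nn.functional.silu(pt_out)
elif activation == ttnn.UnaryOpType.RELU:
    pt_out = torch.nn.functional.relu(pt_out)

passing, output = comp_pcc(pt_out, tt_out)
assert passing

# Every launch after the first should reuse the compiled program,
# so exactly one program-cache entry is expected for this op.
assert device.num_program_cache_entries() == 1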
