Commit aa7d918

#11512: Refactor hardtanh.py sweep

amalbasaTT committed Nov 22, 2024
1 parent ff07e81

Showing 2 changed files with 9 additions and 6 deletions.
.github/workflows/ttnn-run-sweeps.yaml (1 addition, 1 deletion)

The only change in this file fixes a module-path typo: the hardswish_sharded sweep was listed under eltwise.unary.hardsigmoid instead of eltwise.unary.hardswish.

@@ -102,7 +102,7 @@ on:
           - eltwise.unary.hardtanh.hardtanh
           - eltwise.unary.hardtanh.hardtanh_sharded
           - eltwise.unary.hardswish.hardswish
-          - eltwise.unary.hardsigmoid.hardswish_sharded
+          - eltwise.unary.hardswish.hardswish_sharded
           - eltwise.unary.hardsigmoid.hardsigmoid
           - eltwise.unary.hardsigmoid.hardsigmoid_sharded
           - eltwise.unary.hardshrink.hardshrink
tests/sweep_framework/sweeps/eltwise/unary/hardtanh/hardtanh.py (8 additions, 5 deletions)

@@ -7,7 +7,7 @@
 
 import torch
 import ttnn
-from tests.sweep_framework.sweep_utils.utils import gen_shapes
+from tests.sweep_framework.sweep_utils.utils import gen_shapes, sanitize_shape_rm
 from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt
 
 from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
@@ -20,9 +20,9 @@
 # Developers can create their own generator functions and pass them to the parameters as inputs.
 parameters = {
     "nightly": {
-        "input_shape": gen_shapes([1, 1, 32, 32], [6, 12, 256, 256], [1, 1, 32, 32], 16)
-        + gen_shapes([1, 32, 32], [12, 256, 256], [1, 32, 32], 16)
-        + gen_shapes([32, 32], [256, 256], [32, 32], 16),
+        "input_shape": gen_shapes([1, 1, 1, 1], [6, 12, 256, 256], [1, 1, 1, 1], 16)
+        + gen_shapes([1, 1, 1], [12, 256, 256], [1, 1, 1], 16)
+        + gen_shapes([1, 1], [256, 256], [1, 1], 16),
         "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b],
         "input_a_layout": [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT],
         "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG],
@@ -55,6 +55,9 @@ def run(
 ) -> list:
     torch.manual_seed(0)
 
+    if input_a_layout == ttnn.ROW_MAJOR_LAYOUT:
+        input_shape = sanitize_shape_rm(input_shape)
+
     torch_input_tensor_a = gen_func_with_cast_tt(
         partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype
     )(input_shape)
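
ROW_MAJOR_LAYOUT tensors cannot accept every shape the widened generator produces, hence the new guard. A plausible sketch of sanitize_shape_rm, assuming it pads an odd innermost dimension up to the next even size (the actual helper is defined in tests/sweep_framework/sweep_utils/utils.py):

    def sanitize_shape_rm(input_shape):
        # Assumption: row-major tensors need an even innermost dimension,
        # so bump an odd last dim up by one.
        if input_shape[-1] % 2 != 0:
            input_shape[-1] += 1
        return input_shape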
@@ -72,7 +75,7 @@
 
     start_time = start_measuring_time()
     output_tensor = ttnn.hardtanh(input_tensor_a, memory_config=output_memory_config)
-    output_tensor = ttnn.to_torch(output_tensor)
     e2e_perf = stop_measuring_time(start_time)
+    output_tensor = ttnn.to_torch(output_tensor)
 
     return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf]

This reordering moves the device-to-host ttnn.to_torch conversion out of the measured window, so e2e_perf now covers only the ttnn.hardtanh call itself.
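
For reference, a minimal sketch of the timing helpers from tests/ttnn/utils_for_testing, assuming they are thin wrappers over a wall-clock nanosecond timer (the real implementations may differ in units or recorded metadata):

    import time

    def start_measuring_time():
        # Assumed implementation: wall-clock timestamp in nanoseconds.
        return time.time_ns()

    def stop_measuring_time(start_time):
        # Elapsed nanoseconds since start_measuring_time().
        return time.time_ns() - start_time

check_with_pcc then compares the torch reference against the ttnn output by Pearson correlation coefficient, passing when the PCC is at least 0.999.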
