Skip to content

Commit

Permalink
Add reduce scatter to pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
Aswinmcw committed Nov 27, 2024
1 parent 70d6dc0 commit 96fd763
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 26 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/tg-model-perf-tests-impl.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ jobs:
runs-on: ["arch-wormhole_b0", "config-tg", "in-service", "bare-metal", "pipeline-perf"],
cmd: './tests/scripts/run_tests.sh --tt-arch wormhole_b0 --pipeline-type cnn_model_perf_tg_device --dispatch-mode ""'
},
{ name: "t3k CCL all_gather perf tests",
{ name: "t3k CCL perf tests",
arch: wormhole_b0,
cmd: './tests/scripts/run_tests.sh --tt-arch wormhole_b0 --pipeline-type ccl_all_gather_perf_tg_device --dispatch-mode ""',
cmd: './tests/scripts/run_tests.sh --tt-arch wormhole_b0 --pipeline-type ccl_perf_tg_device --dispatch-mode ""',
timeout: 75,
tracy: true,
runs-on: ["arch-wormhole_b0", "config-tg", "in-service", "bare-metal", "pipeline-perf"],
Expand Down
3 changes: 2 additions & 1 deletion tests/scripts/run_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,9 @@ run_pipeline_tests() {
demos_tg_device "$tt_arch" "$pipeline_type" "$dispatch_mode"
elif [[ $pipeline_type == *"model_perf_tg_device" ]]; then
model_perf_tg_device "$tt_arch" "$pipeline_type" "$dispatch_mode"
elif [[ $pipeline_type == "ccl_all_gather_perf_tg_device" ]]; then
elif [[ $pipeline_type == "ccl_perf_tg_device" ]]; then
./tests/ttnn/unit_tests/operations/ccl/perf/run_all_gather_profile.sh -t tg
./tests/ttnn/unit_tests/operations/ccl/perf/run_reduce_scatter_profile.sh -t tg
# TGG pipelines
elif [[ $pipeline_type == "unit_tgg_device" ]]; then
unit_tgg_device "$tt_arch" "$pipeline_type" "$dispatch_mode"
Expand Down
2 changes: 1 addition & 1 deletion tests/ttnn/unit_tests/operations/ccl/perf/test_ccl_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def test_all_gather_on_tg(
@pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True)
@pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum])
@pytest.mark.parametrize("device_params", [{"trace_region_size": 10281600}], indirect=True)
def test_line_reduce_scatter_on_TG_rows_post_commit(
def test_reduce_scatter_on_tg(
mesh_device,
num_devices,
per_chip_output_shape,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ def run_with_trace(
mesh_device,
all_gather_topology,
input_tensor,
scatter_dim,
dim,
num_links,
math_op,
cluster_axis,
output_mem_config,
n_worker=None,
Expand All @@ -60,8 +61,8 @@ def run_with_trace(
# Compile Run
logger.info("Compiling model")
tt_out_tensor = ttnn.reduce_scatter(
ttnn_tensor,
scatter_dim=scatter_dim,
input_tensor,
dim=dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
math_op=math_op,
Expand All @@ -77,8 +78,8 @@ def run_with_trace(
trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
for i in range(num_iter):
tt_out_tensor = ttnn.reduce_scatter(
ttnn_tensor,
scatter_dim=scatter_dim,
input_tensor,
dim=dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
math_op=math_op,
Expand Down Expand Up @@ -198,22 +199,9 @@ def run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows(
)
ttnn_tensor = ttnn.to_device(ttnn_tensor, mesh_device)

# ttnn.visualize_mesh_device(mesh_device, tensor=ttnn_tensor)
ttnn_tensor_out = ttnn.reduce_scatter(
ttnn_tensor,
dim=dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
math_op=math_op,
num_links=num_links,
memory_config=output_mem_config,
topology=ttnn.Topology.Linear,
)
trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
# ttnn.visualize_mesh_device(mesh_device, tensor=ttnn_tensor)
for _ in range(num_iters):
ttnn_tensor_out = ttnn.reduce_scatter(
ttnn_tensor,
if trace_mode:
ttnn_tensor_out = run_with_trace(
input_tensor=ttnn_tensor,
dim=dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
Expand All @@ -227,7 +215,7 @@ def run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows(
for _ in range(num_iters):
ttnn_tensor_out = ttnn.reduce_scatter(
ttnn_tensor,
scatter_dim=dim,
dim=dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
math_op=math_op,
Expand Down

0 comments on commit 96fd763

Please sign in to comment.