Skip to content

Commit

Permalink
Add reduce scatter perf to tg
Browse files Browse the repository at this point in the history
  • Loading branch information
Aswinmcw committed Nov 27, 2024
1 parent a55d744 commit b53b929
Show file tree
Hide file tree
Showing 3 changed files with 144 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ show_help() {
echo
echo "Options:"
echo " -d, --debug Enable debug mode to show real-time output."
echo " -t, --target Specify the target configuration (t3000 or n300). Default is n300."
echo " -t, --target Specify the target configuration (t3000 or n300 or tg). Default is n300."
echo " -h, --help Display this help message."
echo
echo "Example:"
Expand Down Expand Up @@ -42,8 +42,8 @@ while [ $# -gt 0 ]; do
shift 2

# Validate the target value
if [ "$TARGET" != "t3000" ] && [ "$TARGET" != "n300" ]; then
echo "Error: Invalid target configuration: $TARGET. Must be either 't3000' or 'n300'."
if [ "$TARGET" != "t3000" ] && [ "$TARGET" != "tg" ] && [ "$TARGET" != "n300" ]; then
echo "Error: Invalid target configuration: $TARGET. Must be either 't3000', 'n300' or 'tg'."
exit 1
fi
;;
Expand Down
68 changes: 68 additions & 0 deletions tests/ttnn/unit_tests/operations/ccl/perf/test_ccl_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from tests.ttnn.unit_tests.operations.ccl.test_all_gather_TG_post_commit import (
run_line_all_gather_on_TG_with_mesh_tensor_along_rows,
)
from tests.ttnn.unit_tests.operations.ccl.test_reduce_scatter_TG_nightly import (
run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows,
)


@skip_for_grayskull("Requires eth connected devices to run")
Expand Down Expand Up @@ -332,3 +335,68 @@ def test_all_gather_on_tg(
cluster_axis=1,
trace_mode=True,
)


@skip_for_grayskull("Requires eth connected devices to run")
@pytest.mark.parametrize(
    # Each tuple: (devices per reduce-scatter line, eth links used,
    # per-chip output shape, scatter dim, tensor layout).
    "num_devices, num_links, per_chip_output_shape, dim, layout",
    [
        (4, 2, [1, 4, 32, 2304], 1, ttnn.TILE_LAYOUT),
        (4, 2, [1, 4, 64, 2304], 1, ttnn.TILE_LAYOUT),
        (4, 2, [1, 4, 64, 6656], 1, ttnn.TILE_LAYOUT),
    ],
)
@pytest.mark.parametrize(
    "input_dtype",
    [
        ttnn.bfloat16,
    ],
)
@pytest.mark.parametrize(
    "buffer_type",
    [
        ttnn.BufferType.DRAM,
        ttnn.BufferType.L1,
    ],
)
# 8 parallel reduce-scatter lines across the 8x4 mesh (one per row group).
@pytest.mark.parametrize("replication_factor", [8])
@pytest.mark.parametrize("enable_async", [True])
@pytest.mark.parametrize("num_iters", [20])
@pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True)
@pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum])
# trace_region_size must be large enough to hold the captured trace buffer.
@pytest.mark.parametrize("device_params", [{"trace_region_size": 10281600}], indirect=True)
def test_line_reduce_scatter_on_TG_rows_post_commit(
    mesh_device,
    num_devices,
    per_chip_output_shape,
    dim,
    num_links,
    math_op,
    input_dtype,
    layout,
    buffer_type,
    use_program_cache,
    function_level_defaults,
    enable_async,
    replication_factor,
    num_iters,
):
    """Trace-mode perf test: line reduce-scatter along the rows (cluster_axis=1)
    of a TG 8x4 mesh.

    Thin parametrized wrapper — all setup, execution, and result checking is
    delegated to run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows,
    invoked with trace_mode=True so iterations run from a captured trace.
    """
    run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows(
        mesh_device,
        num_devices,
        per_chip_output_shape,
        ttnn.TensorMemoryLayout.INTERLEAVED,
        dim,
        num_links,
        math_op,
        input_dtype,
        layout,
        buffer_type,
        use_program_cache,
        function_level_defaults,
        enable_async=enable_async,
        num_iters=num_iters,
        num_reduce_scatter_instances=replication_factor,
        cluster_axis=1,
        trace_mode=True,
    )
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,61 @@ def print_tile_corners_of_tensor(t):
print(f"{str_vals}")


def run_with_trace(
    mesh_device,
    all_gather_topology,
    input_tensor,
    scatter_dim,
    num_links,
    cluster_axis,
    output_mem_config,
    n_worker=None,
    n_buffer=None,
    num_iter=20,
    math_op=None,
):
    """Compile, trace-capture, and replay a reduce-scatter on a mesh device.

    Runs the op once eagerly to compile it, captures ``num_iter`` iterations
    into a device trace, then replays the trace non-blocking for perf
    measurement.

    Args:
        mesh_device: The mesh device to run on.
        all_gather_topology: ttnn.Topology used for the collective
            (callers pass ttnn.Topology.Linear).
        input_tensor: Mesh tensor to reduce-scatter.
        scatter_dim: Tensor dimension to scatter along.
        num_links: Number of ethernet links to use.
        cluster_axis: Mesh axis (0=cols, 1=rows) the op runs along.
        output_mem_config: Memory config for the output tensor.
        n_worker: Unused; kept for signature parity with the all-gather
            trace helper.
        n_buffer: Unused; kept for signature parity.
        num_iter: Number of iterations captured in the trace.
        math_op: Reduction op (e.g. ttnn.ReduceType.Sum). Must be provided;
            defaults to None only for backward compatibility.

    Returns:
        The output tensor from the last traced reduce-scatter.
    """
    # Compile run: the first invocation builds the program so compilation
    # does not end up inside the captured trace.
    logger.info("Compiling model")
    # BUG FIX: original referenced undefined names `ttnn_tensor` and a
    # `math_op` that was never a parameter; also honor the topology argument
    # instead of hard-coding Linear.
    tt_out_tensor = ttnn.reduce_scatter(
        input_tensor,
        scatter_dim=scatter_dim,
        cluster_axis=cluster_axis,
        mesh_device=mesh_device,
        math_op=math_op,
        num_links=num_links,
        memory_config=output_mem_config,
        topology=all_gather_topology,
    )
    for d in mesh_device.get_devices():
        ttnn.synchronize_device(d)

    # Capture trace
    logger.info("Capturing trace")
    trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
    for i in range(num_iter):
        tt_out_tensor = ttnn.reduce_scatter(
            input_tensor,
            scatter_dim=scatter_dim,
            cluster_axis=cluster_axis,
            mesh_device=mesh_device,
            math_op=math_op,
            num_links=num_links,
            memory_config=output_mem_config,
            topology=all_gather_topology,
        )
    ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)
    for d in mesh_device.get_devices():
        ttnn.synchronize_device(d)

    # Run the op: replay the captured trace, then release it.
    logger.info("Starting Trace perf test...")
    ttnn.execute_trace(mesh_device, trace_id, blocking=False)
    ttnn.release_trace(mesh_device, trace_id)
    for d in mesh_device.get_devices():
        ttnn.synchronize_device(d)

    return tt_out_tensor


def run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows(
mesh_device,
num_devices_per_line,
Expand All @@ -63,6 +118,7 @@ def run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows(
num_reduce_scatter_instances: int = 1,
num_iters: int = 1,
cluster_axis: int = 0,
trace_mode=False,
):
if len(mesh_device.get_devices()) != 32:
pytest.skip("Not TG!")
Expand Down Expand Up @@ -163,18 +219,24 @@ def run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows(
mesh_device=mesh_device,
math_op=math_op,
num_links=num_links,
memory_config=output_mem_config,
topology=ttnn.Topology.Linear,
output_mem_config=output_mem_config,
all_gather_topology=ttnn.Topology.Linear,
num_iter=num_iters,
)
ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)
for d in mesh_device.get_devices():
ttnn.synchronize_device(d)

logger.info("Starting Trace perf test...")
ttnn.execute_trace(mesh_device, trace_id, blocking=False)
ttnn.release_trace(mesh_device, trace_id)
for d in mesh_device.get_devices():
ttnn.synchronize_device(d)
else:
for _ in range(num_iters):
ttnn_tensor_out = ttnn.reduce_scatter(
ttnn_tensor,
scatter_dim=dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
math_op=math_op,
num_links=num_links,
memory_config=output_mem_config,
topology=ttnn.Topology.Linear,
)
for d in mesh_device.get_devices():
ttnn.synchronize_device(d)

# ttnn.visualize_mesh_device(mesh_device, tensor=ttnn_tensor_out)
tt_output_tensor = ttnn.to_torch(
Expand Down Expand Up @@ -290,7 +352,6 @@ def test_line_reduce_scatter_on_TG_rows_post_commit(
@pytest.mark.parametrize("replication_factor", [4])
@pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True)
@pytest.mark.parametrize("math_op", [ttnn.ReduceType.Sum])
@pytest.mark.parametrize("device_params", [{"trace_region_size": 10281600}], indirect=True)
def test_line_reduce_scatter_on_TG_cols_post_commit(
mesh_device,
num_devices,
Expand Down

0 comments on commit b53b929

Please sign in to comment.