[CCL] Add negative dim support (#15305)
### Ticket
#15288 

### Problem description
Passing a negative `dim` to CCL ops throws an error.

### What's changed
Adds negative `dim` support across CCL ops; call sites and tests now pass the argument as `dim` (previously `scatter_dim`).
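
For illustration only, a minimal sketch of how a call site can now pass a negative `dim`. It assumes `dense_out` is an existing rank-4 mesh-sharded `ttnn.Tensor`; the argument names mirror the `ttnn.reduce_scatter` calls updated in this diff, but the snippet itself is not part of the commit:

```python
import ttnn

# Assumption: dense_out is a rank-4 tensor sharded across a mesh device.
# With negative-dim support, dim=-1 should resolve to dim + rank = 3,
# i.e. the same scatter dimension these models previously passed explicitly.
dense_out_reduced = ttnn.reduce_scatter(
    dense_out,
    dim=-1,                          # equivalent to dim=3 for a 4-D tensor
    math_op=ttnn.ReduceType.Sum,
    num_links=1,
    memory_config=ttnn.L1_MEMORY_CONFIG,
)
```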

### Checklist
- [ ] Post commit CI passes
- [ ] Blackhole Post commit (if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [ ] New/Existing tests provide coverage for changes
Aswinmcw authored and spoojaryTT committed Nov 25, 2024
1 parent f73dc7f commit aeca52f
Showing 24 changed files with 125 additions and 100 deletions.
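
As a hedged sketch of what "negative dim support" usually means (the op-level implementation is not among the files shown on this page), the dimension would be normalized before validation, so `-1` selects the last dimension and `-rank` the first. The helper below is hypothetical and only illustrates that convention:

```python
def normalize_dim(dim: int, rank: int) -> int:
    """Hypothetical illustration of the usual negative-dim convention.

    The real normalization lives inside the ttnn CCL ops, not in this file.
    """
    if dim < 0:
        dim += rank  # e.g. dim=-1 with rank=4 becomes 3
    if not 0 <= dim < rank:
        raise ValueError(f"dim {dim} is out of range for a rank-{rank} tensor")
    return dim
```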
4 changes: 2 additions & 2 deletions models/demos/llama3/tt/llama_attention.py
@@ -357,7 +357,7 @@ def forward_decode(
if self.is_multichip and not self.use_fused_all_gather_matmul:
dense_out_reduced = ttnn.reduce_scatter(
dense_out,
scatter_dim=3,
dim=3,
math_op=ttnn.ReduceType.Sum,
num_links=1,
memory_config=ttnn.L1_MEMORY_CONFIG,
@@ -530,7 +530,7 @@ def forward_prefill(self, x_11SH, rot_mats, transformation_mats, user_id: int =
if self.is_multichip and not self.use_fused_all_gather_matmul:
dense_out_reduced = ttnn.reduce_scatter(
output_11SH,
scatter_dim=3,
dim=3,
math_op=ttnn.ReduceType.Sum,
num_links=1,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
2 changes: 1 addition & 1 deletion models/demos/llama3/tt/llama_mlp.py
@@ -137,7 +137,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
if self.args.is_multichip:
w2_out_reduced = ttnn.reduce_scatter(
w2_out,
scatter_dim=3,
dim=3,
math_op=ttnn.ReduceType.Sum,
num_links=1,
memory_config=ttnn.DRAM_MEMORY_CONFIG if mode == "prefill" else ttnn.L1_MEMORY_CONFIG,
4 changes: 2 additions & 2 deletions models/demos/llama3/tt/multimodal/llama_cross_attention.py
@@ -271,7 +271,7 @@ def forward_decode(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH,
if self.is_multichip:
dense_out_reduced = ttnn.reduce_scatter(
output,
scatter_dim=3,
dim=3,
math_op=ttnn.ReduceType.Sum,
num_links=1,
memory_config=ttnn.L1_MEMORY_CONFIG,
@@ -358,7 +358,7 @@ def forward_prefill(
if self.is_multichip: # TODO use_fused_all_gather_matmul
dense_out_reduced = ttnn.reduce_scatter(
output,
scatter_dim=3,
dim=3,
math_op=ttnn.ReduceType.Sum,
num_links=1,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
4 changes: 2 additions & 2 deletions models/demos/qwen/tt/qwen_attention.py
@@ -414,7 +414,7 @@ def forward_decode(
if self.is_multichip and not self.use_fused_all_gather_matmul:
dense_out_reduced = ttnn.reduce_scatter(
dense_out,
scatter_dim=3,
dim=3,
math_op=ttnn.ReduceType.Sum,
num_links=1,
memory_config=ttnn.L1_MEMORY_CONFIG,
@@ -598,7 +598,7 @@ def forward_prefill(self, x_11SH, rot_mats, transformation_mats, user_id: int =
if self.is_multichip and not self.use_fused_all_gather_matmul:
dense_out_reduced = ttnn.reduce_scatter(
output_11SH,
scatter_dim=3,
dim=3,
math_op=ttnn.ReduceType.Sum,
num_links=1,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
2 changes: 1 addition & 1 deletion models/demos/qwen/tt/qwen_mlp.py
@@ -142,7 +142,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor:
if self.args.is_multichip:
w2_out_reduced = ttnn.reduce_scatter(
w2_out,
scatter_dim=3,
dim=3,
math_op=ttnn.ReduceType.Sum,
num_links=1,
memory_config=ttnn.DRAM_MEMORY_CONFIG if mode == "prefill" else ttnn.L1_MEMORY_CONFIG,
4 changes: 2 additions & 2 deletions models/demos/t3000/falcon40b/tt/falcon_mlp.py
@@ -124,7 +124,7 @@ def fwd_decode(self, x: List[ttnn.Tensor]) -> List[ttnn.Tensor]:
hidden_states = ttnn.get_device_tensors(
ttnn.reduce_scatter(
ttnn.aggregate_as_tensor(hidden_states),
scatter_dim=3,
dim=3,
math_op=ttnn.ReduceType.Sum,
num_links=1, # only unidirectional supported for now
memory_config=self.model_config["DEFAULT_MEMCFG"],
@@ -200,7 +200,7 @@ def fwd_prefill(self, x: List[ttnn.Tensor]) -> List[ttnn.Tensor]:
hidden_states = ttnn.get_device_tensors(
ttnn.reduce_scatter(
ttnn.aggregate_as_tensor(hidden_states),
scatter_dim=3,
dim=3,
math_op=ttnn.ReduceType.Sum,
num_links=1, # only one link supported for now
memory_config=self.model_config["DEFAULT_MEMCFG"],
4 changes: 2 additions & 2 deletions models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py
@@ -219,7 +219,7 @@ def prefill_forward(self, x: List[ttnn.Tensor]) -> List[ttnn.Tensor]:

hidden_states_reduced = ttnn.reduce_scatter(
hidden_states_mm,
scatter_dim=3,
dim=3,
math_op=ttnn.ReduceType.Sum,
num_links=1,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
@@ -268,7 +268,7 @@ def decode_forward(self, x: List[ttnn.Tensor]) -> List[ttnn.Tensor]:

hidden_states_reduced = ttnn.reduce_scatter(
hidden_states,
scatter_dim=3,
dim=3,
math_op=ttnn.ReduceType.Sum,
num_links=1,
memory_config=self.model_config["RESIDUAL_16_CORES_OUTPUT_MEMCFG"],
2 changes: 1 addition & 1 deletion models/demos/tg/llama3_70b/tt/llama_common.py
@@ -93,7 +93,7 @@ def tt_composite_sharded_all_reduce(
input_mem_cfg = input_tensor.memory_config()
reduce_scattered_tensor = ttnn.reduce_scatter(
input_tensor,
scatter_dim=dim,
dim=dim,
math_op=ttnn.ReduceType.Sum,
num_links=num_links,
cluster_axis=cluster_axis,
12 changes: 6 additions & 6 deletions tests/ttnn/unit_tests/operations/ccl/perf/test_ccl_perf.py
@@ -141,7 +141,7 @@ def test_all_gather_on_t3000(
],
)
@pytest.mark.parametrize(
"per_chip_output_shape, scatter_dim, layout",
"per_chip_output_shape, dim, layout",
[
([1, 8, 1024, 1024], 3, ttnn.TILE_LAYOUT),
([1, 4, 1024, 1024], 3, ttnn.TILE_LAYOUT),
@@ -171,7 +171,7 @@ def test_reduce_scatter_on_t3000(
t3k_mesh_device,
num_devices,
per_chip_output_shape,
scatter_dim,
dim,
num_links,
math_op,
input_dtype,
@@ -187,7 +187,7 @@ def test_reduce_scatter_on_t3000(
t3k_mesh_device,
num_devices,
per_chip_output_shape,
scatter_dim,
dim,
num_links,
math_op,
input_dtype,
@@ -210,7 +210,7 @@ def test_reduce_scatter_on_t3000(
],
)
@pytest.mark.parametrize(
"per_chip_output_shape, scatter_dim, layout",
"per_chip_output_shape, dim, layout",
[
([1, 1, 32, 4096], 3, ttnn.TILE_LAYOUT),
([1, 1, 32, 2048], 3, ttnn.TILE_LAYOUT),
@@ -239,7 +239,7 @@ def test_reduce_scatter_on_n300(
n300_mesh_device,
num_devices,
per_chip_output_shape,
scatter_dim,
dim,
num_links,
math_op,
input_dtype,
@@ -254,7 +254,7 @@ def test_reduce_scatter_on_n300(
n300_mesh_device,
num_devices,
per_chip_output_shape,
scatter_dim,
dim,
num_links,
math_op,
input_dtype,
@@ -22,15 +22,15 @@
[
(4, 1, [4, 1, 33, 256], 0, ttnn.ROW_MAJOR_LAYOUT),
(8, 1, [8, 1, 33, 256], 0, ttnn.ROW_MAJOR_LAYOUT),
(8, 1, [8, 1, 256, 32], 0, ttnn.TILE_LAYOUT),
(8, 1, [8, 1, 256, 32], -4, ttnn.TILE_LAYOUT),
(8, 1, [8, 8, 256, 384], 1, ttnn.ROW_MAJOR_LAYOUT),
# (4, 2, [8, 8, 256, 384], 1, ttnn.TILE_LAYOUT),
(8, 1, [8, 8, 256, 384], 1, ttnn.TILE_LAYOUT),
(4, 1, [8, 5, 13, 384], 3, ttnn.ROW_MAJOR_LAYOUT),
(8, 1, [8, 5, 13, 512], 3, ttnn.ROW_MAJOR_LAYOUT),
(4, 1, [8, 5, 13, 384], -1, ttnn.ROW_MAJOR_LAYOUT),
(8, 1, [8, 5, 13, 512], -1, ttnn.ROW_MAJOR_LAYOUT),
(4, 1, [8, 5, 32, 384], 3, ttnn.TILE_LAYOUT),
(8, 1, [8, 5, 32, 512], 3, ttnn.TILE_LAYOUT),
(4, 1, [1, 1, 32, 16384], 3, ttnn.TILE_LAYOUT),
(4, 1, [1, 1, 32, 16384], -1, ttnn.TILE_LAYOUT),
],
)
@pytest.mark.parametrize(
@@ -20,7 +20,7 @@
],
)
@pytest.mark.parametrize(
"per_chip_output_shape, scatter_dim, layout",
"per_chip_output_shape, dim, layout",
[
([1, 1, 32, 4096], 3, ttnn.TILE_LAYOUT),
([1, 1, 32, 2048], 3, ttnn.TILE_LAYOUT),
@@ -50,7 +50,7 @@ def test_ring_reduce_scatter_n300_post_commit(
n300_mesh_device,
num_devices,
per_chip_output_shape,
scatter_dim,
dim,
num_links,
math_op,
input_dtype,
@@ -65,7 +65,7 @@ def test_ring_reduce_scatter_n300_post_commit(
n300_mesh_device,
num_devices,
per_chip_output_shape,
scatter_dim,
dim,
num_links,
math_op,
input_dtype,
@@ -145,7 +145,7 @@ def run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows(
# ttnn.visualize_mesh_device(mesh_device, tensor=ttnn_tensor)
ttnn_tensor_out = ttnn.reduce_scatter(
ttnn_tensor,
scatter_dim=dim,
dim=dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
math_op=math_op,
@@ -158,7 +158,7 @@
for _ in range(num_iters):
ttnn_tensor_out = ttnn.reduce_scatter(
ttnn_tensor,
scatter_dim=dim,
dim=dim,
cluster_axis=cluster_axis,
mesh_device=mesh_device,
math_op=math_op,
28 changes: 14 additions & 14 deletions tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_nightly.py
@@ -19,23 +19,23 @@
],
)
@pytest.mark.parametrize(
"per_chip_output_shape, scatter_dim, layout",
"per_chip_output_shape, dim, layout",
[
([1, 8, 1024, 1024], 3, ttnn.TILE_LAYOUT),
([1, 4, 1024, 1024], 3, ttnn.TILE_LAYOUT),
([1, 4, 2048, 1024], 3, ttnn.TILE_LAYOUT),
([1, 1, 32, 32], 3, ttnn.TILE_LAYOUT),
([1, 1, 32, 64], 3, ttnn.TILE_LAYOUT),
([1, 1, 32, 32], -1, ttnn.TILE_LAYOUT),
([1, 1, 32, 64], -1, ttnn.TILE_LAYOUT),
([1, 1, 64, 64], 3, ttnn.TILE_LAYOUT),
([1, 1, 32, 128], 3, ttnn.TILE_LAYOUT),
([1, 1, 32, 128], -1, ttnn.TILE_LAYOUT),
([1, 1, 32, 256], 3, ttnn.TILE_LAYOUT),
([1, 1, 32, 512], 3, ttnn.TILE_LAYOUT),
([1, 1, 32, 1024], 3, ttnn.TILE_LAYOUT),
([1, 1, 32, 2048], 3, ttnn.TILE_LAYOUT),
([1, 1, 128, 1024], 3, ttnn.TILE_LAYOUT),
([1, 1, 128, 1024], -1, ttnn.TILE_LAYOUT),
([1, 1, 128, 8192], 3, ttnn.TILE_LAYOUT),
([1, 1, 2048, 1024], 3, ttnn.TILE_LAYOUT),
([1, 1, 2048, 8192], 3, ttnn.TILE_LAYOUT),
([1, 1, 2048, 8192], -1, ttnn.TILE_LAYOUT),
],
)
@pytest.mark.parametrize(
@@ -58,7 +58,7 @@ def test_reduce_scatter_t3k_8chip_nightly(
t3k_mesh_device,
num_devices,
per_chip_output_shape,
scatter_dim,
dim,
num_links,
math_op,
input_dtype,
@@ -73,7 +73,7 @@ def test_reduce_scatter_t3k_8chip_nightly(
t3k_mesh_device,
num_devices,
per_chip_output_shape,
scatter_dim,
dim,
num_links,
math_op,
input_dtype,
@@ -95,16 +95,16 @@ def test_reduce_scatter_t3k_8chip_nightly(
],
)
@pytest.mark.parametrize(
"per_chip_output_shape, scatter_dim, layout",
"per_chip_output_shape, dim, layout",
[
([1, 8, 1024, 1024], 3, ttnn.TILE_LAYOUT),
([1, 4, 1024, 1024], 3, ttnn.TILE_LAYOUT),
([1, 4, 2048, 1024], 3, ttnn.TILE_LAYOUT),
([1, 4, 2048, 1024], -1, ttnn.TILE_LAYOUT),
([1, 1, 32, 512], 3, ttnn.TILE_LAYOUT),
([1, 1, 32, 1024], 3, ttnn.TILE_LAYOUT),
([1, 1, 32, 2048], 3, ttnn.TILE_LAYOUT),
([1, 1, 32, 2048], -1, ttnn.TILE_LAYOUT),
([1, 1, 128, 1024], 3, ttnn.TILE_LAYOUT),
([1, 1, 128, 8192], 3, ttnn.TILE_LAYOUT),
([1, 1, 128, 8192], -1, ttnn.TILE_LAYOUT),
([1, 1, 2048, 1024], 3, ttnn.TILE_LAYOUT),
([1, 1, 2048, 8192], 3, ttnn.TILE_LAYOUT),
# These shapes result in some workers with no work, which is currently
@@ -136,7 +136,7 @@ def test_reduce_scatter_t3k_4chip_nightly(
pcie_mesh_device,
num_devices,
per_chip_output_shape,
scatter_dim,
dim,
num_links,
math_op,
input_dtype,
@@ -151,7 +151,7 @@ def test_reduce_scatter_t3k_4chip_nightly(
pcie_mesh_device,
num_devices,
per_chip_output_shape,
scatter_dim,
dim,
num_links,
math_op,
input_dtype,