diff --git a/models/demos/llama3/tt/llama_attention.py b/models/demos/llama3/tt/llama_attention.py index 9a51aad2a74..927c0a6ed82 100644 --- a/models/demos/llama3/tt/llama_attention.py +++ b/models/demos/llama3/tt/llama_attention.py @@ -357,7 +357,7 @@ def forward_decode( if self.is_multichip and not self.use_fused_all_gather_matmul: dense_out_reduced = ttnn.reduce_scatter( dense_out, - scatter_dim=3, + dim=3, math_op=ttnn.ReduceType.Sum, num_links=1, memory_config=ttnn.L1_MEMORY_CONFIG, @@ -530,7 +530,7 @@ def forward_prefill(self, x_11SH, rot_mats, transformation_mats, user_id: int = if self.is_multichip and not self.use_fused_all_gather_matmul: dense_out_reduced = ttnn.reduce_scatter( output_11SH, - scatter_dim=3, + dim=3, math_op=ttnn.ReduceType.Sum, num_links=1, memory_config=ttnn.DRAM_MEMORY_CONFIG, diff --git a/models/demos/llama3/tt/llama_mlp.py b/models/demos/llama3/tt/llama_mlp.py index f06e2ff63f1..88b44927715 100644 --- a/models/demos/llama3/tt/llama_mlp.py +++ b/models/demos/llama3/tt/llama_mlp.py @@ -137,7 +137,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor: if self.args.is_multichip: w2_out_reduced = ttnn.reduce_scatter( w2_out, - scatter_dim=3, + dim=3, math_op=ttnn.ReduceType.Sum, num_links=1, memory_config=ttnn.DRAM_MEMORY_CONFIG if mode == "prefill" else ttnn.L1_MEMORY_CONFIG, diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention.py b/models/demos/llama3/tt/multimodal/llama_cross_attention.py index 63f87fbeb73..fb9266c23a1 100644 --- a/models/demos/llama3/tt/multimodal/llama_cross_attention.py +++ b/models/demos/llama3/tt/multimodal/llama_cross_attention.py @@ -271,7 +271,7 @@ def forward_decode(self, x_11SH, xattn_mask, full_text_row_masked_out_mask_1NSH, if self.is_multichip: dense_out_reduced = ttnn.reduce_scatter( output, - scatter_dim=3, + dim=3, math_op=ttnn.ReduceType.Sum, num_links=1, memory_config=ttnn.L1_MEMORY_CONFIG, @@ -358,7 +358,7 @@ def forward_prefill( if self.is_multichip: # TODO use_fused_all_gather_matmul dense_out_reduced = ttnn.reduce_scatter( output, - scatter_dim=3, + dim=3, math_op=ttnn.ReduceType.Sum, num_links=1, memory_config=ttnn.DRAM_MEMORY_CONFIG, diff --git a/models/demos/qwen/tt/qwen_attention.py b/models/demos/qwen/tt/qwen_attention.py index ba598cc96c1..0e80c47b228 100644 --- a/models/demos/qwen/tt/qwen_attention.py +++ b/models/demos/qwen/tt/qwen_attention.py @@ -414,7 +414,7 @@ def forward_decode( if self.is_multichip and not self.use_fused_all_gather_matmul: dense_out_reduced = ttnn.reduce_scatter( dense_out, - scatter_dim=3, + dim=3, math_op=ttnn.ReduceType.Sum, num_links=1, memory_config=ttnn.L1_MEMORY_CONFIG, @@ -598,7 +598,7 @@ def forward_prefill(self, x_11SH, rot_mats, transformation_mats, user_id: int = if self.is_multichip and not self.use_fused_all_gather_matmul: dense_out_reduced = ttnn.reduce_scatter( output_11SH, - scatter_dim=3, + dim=3, math_op=ttnn.ReduceType.Sum, num_links=1, memory_config=ttnn.DRAM_MEMORY_CONFIG, diff --git a/models/demos/qwen/tt/qwen_mlp.py b/models/demos/qwen/tt/qwen_mlp.py index e07d4943d1c..ad500853920 100644 --- a/models/demos/qwen/tt/qwen_mlp.py +++ b/models/demos/qwen/tt/qwen_mlp.py @@ -142,7 +142,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor: if self.args.is_multichip: w2_out_reduced = ttnn.reduce_scatter( w2_out, - scatter_dim=3, + dim=3, math_op=ttnn.ReduceType.Sum, num_links=1, memory_config=ttnn.DRAM_MEMORY_CONFIG if mode == "prefill" else ttnn.L1_MEMORY_CONFIG, diff --git a/models/demos/t3000/falcon40b/tt/falcon_mlp.py b/models/demos/t3000/falcon40b/tt/falcon_mlp.py index 1788c3ac6b6..5101b309d4d 100644 --- a/models/demos/t3000/falcon40b/tt/falcon_mlp.py +++ b/models/demos/t3000/falcon40b/tt/falcon_mlp.py @@ -124,7 +124,7 @@ def fwd_decode(self, x: List[ttnn.Tensor]) -> List[ttnn.Tensor]: hidden_states = ttnn.get_device_tensors( ttnn.reduce_scatter( ttnn.aggregate_as_tensor(hidden_states), - scatter_dim=3, + dim=3, math_op=ttnn.ReduceType.Sum, num_links=1, # only unidirectional supported for now memory_config=self.model_config["DEFAULT_MEMCFG"], @@ -200,7 +200,7 @@ def fwd_prefill(self, x: List[ttnn.Tensor]) -> List[ttnn.Tensor]: hidden_states = ttnn.get_device_tensors( ttnn.reduce_scatter( ttnn.aggregate_as_tensor(hidden_states), - scatter_dim=3, + dim=3, math_op=ttnn.ReduceType.Sum, num_links=1, # only one link supported for now memory_config=self.model_config["DEFAULT_MEMCFG"], diff --git a/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py b/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py index 2861253da1a..a185fc605f0 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py +++ b/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py @@ -219,7 +219,7 @@ def prefill_forward(self, x: List[ttnn.Tensor]) -> List[ttnn.Tensor]: hidden_states_reduced = ttnn.reduce_scatter( hidden_states_mm, - scatter_dim=3, + dim=3, math_op=ttnn.ReduceType.Sum, num_links=1, memory_config=ttnn.DRAM_MEMORY_CONFIG, @@ -268,7 +268,7 @@ def decode_forward(self, x: List[ttnn.Tensor]) -> List[ttnn.Tensor]: hidden_states_reduced = ttnn.reduce_scatter( hidden_states, - scatter_dim=3, + dim=3, math_op=ttnn.ReduceType.Sum, num_links=1, memory_config=self.model_config["RESIDUAL_16_CORES_OUTPUT_MEMCFG"], diff --git a/models/demos/tg/llama3_70b/tt/llama_common.py b/models/demos/tg/llama3_70b/tt/llama_common.py index 1b16fde6a60..9824afbc44c 100644 --- a/models/demos/tg/llama3_70b/tt/llama_common.py +++ b/models/demos/tg/llama3_70b/tt/llama_common.py @@ -93,7 +93,7 @@ def tt_composite_sharded_all_reduce( input_mem_cfg = input_tensor.memory_config() reduce_scattered_tensor = ttnn.reduce_scatter( input_tensor, - scatter_dim=dim, + dim=dim, math_op=ttnn.ReduceType.Sum, num_links=num_links, cluster_axis=cluster_axis, diff --git a/tests/ttnn/unit_tests/operations/ccl/perf/test_ccl_perf.py b/tests/ttnn/unit_tests/operations/ccl/perf/test_ccl_perf.py index 1429eb0fce1..800d25befb8 100644 --- a/tests/ttnn/unit_tests/operations/ccl/perf/test_ccl_perf.py +++ b/tests/ttnn/unit_tests/operations/ccl/perf/test_ccl_perf.py @@ -141,7 +141,7 @@ def test_all_gather_on_t3000( ], ) @pytest.mark.parametrize( - "per_chip_output_shape, scatter_dim, layout", + "per_chip_output_shape, dim, layout", [ ([1, 8, 1024, 1024], 3, ttnn.TILE_LAYOUT), ([1, 4, 1024, 1024], 3, ttnn.TILE_LAYOUT), @@ -171,7 +171,7 @@ def test_reduce_scatter_on_t3000( t3k_mesh_device, num_devices, per_chip_output_shape, - scatter_dim, + dim, num_links, math_op, input_dtype, @@ -187,7 +187,7 @@ def test_reduce_scatter_on_t3000( t3k_mesh_device, num_devices, per_chip_output_shape, - scatter_dim, + dim, num_links, math_op, input_dtype, @@ -210,7 +210,7 @@ def test_reduce_scatter_on_t3000( ], ) @pytest.mark.parametrize( - "per_chip_output_shape, scatter_dim, layout", + "per_chip_output_shape, dim, layout", [ ([1, 1, 32, 4096], 3, ttnn.TILE_LAYOUT), ([1, 1, 32, 2048], 3, ttnn.TILE_LAYOUT), @@ -239,7 +239,7 @@ def test_reduce_scatter_on_n300( n300_mesh_device, num_devices, per_chip_output_shape, - scatter_dim, + dim, num_links, math_op, input_dtype, @@ -254,7 +254,7 @@ def test_reduce_scatter_on_n300( n300_mesh_device, num_devices, per_chip_output_shape, - scatter_dim, + dim, num_links, math_op, input_dtype, diff --git a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_nightly.py b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_nightly.py index ad1d7a63abe..3313d73880c 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_all_gather_nightly.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_all_gather_nightly.py @@ -22,15 +22,15 @@ [ (4, 1, [4, 1, 33, 256], 0, ttnn.ROW_MAJOR_LAYOUT), (8, 1, [8, 1, 33, 256], 0, ttnn.ROW_MAJOR_LAYOUT), - (8, 1, [8, 1, 256, 32], 0, ttnn.TILE_LAYOUT), + (8, 1, [8, 1, 256, 32], -4, ttnn.TILE_LAYOUT), (8, 1, [8, 8, 256, 384], 1, ttnn.ROW_MAJOR_LAYOUT), # (4, 2, [8, 8, 256, 384], 1, ttnn.TILE_LAYOUT), (8, 1, [8, 8, 256, 384], 1, ttnn.TILE_LAYOUT), - (4, 1, [8, 5, 13, 384], 3, ttnn.ROW_MAJOR_LAYOUT), - (8, 1, [8, 5, 13, 512], 3, ttnn.ROW_MAJOR_LAYOUT), + (4, 1, [8, 5, 13, 384], -1, ttnn.ROW_MAJOR_LAYOUT), + (8, 1, [8, 5, 13, 512], -1, ttnn.ROW_MAJOR_LAYOUT), (4, 1, [8, 5, 32, 384], 3, ttnn.TILE_LAYOUT), (8, 1, [8, 5, 32, 512], 3, ttnn.TILE_LAYOUT), - (4, 1, [1, 1, 32, 16384], 3, ttnn.TILE_LAYOUT), + (4, 1, [1, 1, 32, 16384], -1, ttnn.TILE_LAYOUT), ], ) @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_N300_post_commit.py b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_N300_post_commit.py index c34c4fd6191..086efb1d534 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_N300_post_commit.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_N300_post_commit.py @@ -20,7 +20,7 @@ ], ) @pytest.mark.parametrize( - "per_chip_output_shape, scatter_dim, layout", + "per_chip_output_shape, dim, layout", [ ([1, 1, 32, 4096], 3, ttnn.TILE_LAYOUT), ([1, 1, 32, 2048], 3, ttnn.TILE_LAYOUT), @@ -50,7 +50,7 @@ def test_ring_reduce_scatter_n300_post_commit( n300_mesh_device, num_devices, per_chip_output_shape, - scatter_dim, + dim, num_links, math_op, input_dtype, @@ -65,7 +65,7 @@ def test_ring_reduce_scatter_n300_post_commit( n300_mesh_device, num_devices, per_chip_output_shape, - scatter_dim, + dim, num_links, math_op, input_dtype, diff --git a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_TG_nightly.py b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_TG_nightly.py index 9e9fbf479f5..1b5bfe8f672 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_TG_nightly.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_TG_nightly.py @@ -145,7 +145,7 @@ def run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows( # ttnn.visualize_mesh_device(mesh_device, tensor=ttnn_tensor) ttnn_tensor_out = ttnn.reduce_scatter( ttnn_tensor, - scatter_dim=dim, + dim=dim, cluster_axis=cluster_axis, mesh_device=mesh_device, math_op=math_op, @@ -158,7 +158,7 @@ def run_line_reduce_scatter_on_TG_with_mesh_tensor_along_rows( for _ in range(num_iters): ttnn_tensor_out = ttnn.reduce_scatter( ttnn_tensor, - scatter_dim=dim, + dim=dim, cluster_axis=cluster_axis, mesh_device=mesh_device, math_op=math_op, diff --git a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_nightly.py b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_nightly.py index 17eee107972..aaf8e21fc10 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_nightly.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_nightly.py @@ -19,23 +19,23 @@ ], ) @pytest.mark.parametrize( - "per_chip_output_shape, scatter_dim, layout", + "per_chip_output_shape, dim, layout", [ ([1, 8, 1024, 1024], 3, ttnn.TILE_LAYOUT), ([1, 4, 1024, 1024], 3, ttnn.TILE_LAYOUT), ([1, 4, 2048, 1024], 3, ttnn.TILE_LAYOUT), - ([1, 1, 32, 32], 3, ttnn.TILE_LAYOUT), - ([1, 1, 32, 64], 3, ttnn.TILE_LAYOUT), + ([1, 1, 32, 32], -1, ttnn.TILE_LAYOUT), + ([1, 1, 32, 64], -1, ttnn.TILE_LAYOUT), ([1, 1, 64, 64], 3, ttnn.TILE_LAYOUT), - ([1, 1, 32, 128], 3, ttnn.TILE_LAYOUT), + ([1, 1, 32, 128], -1, ttnn.TILE_LAYOUT), ([1, 1, 32, 256], 3, ttnn.TILE_LAYOUT), ([1, 1, 32, 512], 3, ttnn.TILE_LAYOUT), ([1, 1, 32, 1024], 3, ttnn.TILE_LAYOUT), ([1, 1, 32, 2048], 3, ttnn.TILE_LAYOUT), - ([1, 1, 128, 1024], 3, ttnn.TILE_LAYOUT), + ([1, 1, 128, 1024], -1, ttnn.TILE_LAYOUT), ([1, 1, 128, 8192], 3, ttnn.TILE_LAYOUT), ([1, 1, 2048, 1024], 3, ttnn.TILE_LAYOUT), - ([1, 1, 2048, 8192], 3, ttnn.TILE_LAYOUT), + ([1, 1, 2048, 8192], -1, ttnn.TILE_LAYOUT), ], ) @pytest.mark.parametrize( @@ -58,7 +58,7 @@ def test_reduce_scatter_t3k_8chip_nightly( t3k_mesh_device, num_devices, per_chip_output_shape, - scatter_dim, + dim, num_links, math_op, input_dtype, @@ -73,7 +73,7 @@ def test_reduce_scatter_t3k_8chip_nightly( t3k_mesh_device, num_devices, per_chip_output_shape, - scatter_dim, + dim, num_links, math_op, input_dtype, @@ -95,16 +95,16 @@ def test_reduce_scatter_t3k_8chip_nightly( ], ) @pytest.mark.parametrize( - "per_chip_output_shape, scatter_dim, layout", + "per_chip_output_shape, dim, layout", [ ([1, 8, 1024, 1024], 3, ttnn.TILE_LAYOUT), ([1, 4, 1024, 1024], 3, ttnn.TILE_LAYOUT), - ([1, 4, 2048, 1024], 3, ttnn.TILE_LAYOUT), + ([1, 4, 2048, 1024], -1, ttnn.TILE_LAYOUT), ([1, 1, 32, 512], 3, ttnn.TILE_LAYOUT), ([1, 1, 32, 1024], 3, ttnn.TILE_LAYOUT), - ([1, 1, 32, 2048], 3, ttnn.TILE_LAYOUT), + ([1, 1, 32, 2048], -1, ttnn.TILE_LAYOUT), ([1, 1, 128, 1024], 3, ttnn.TILE_LAYOUT), - ([1, 1, 128, 8192], 3, ttnn.TILE_LAYOUT), + ([1, 1, 128, 8192], -1, ttnn.TILE_LAYOUT), ([1, 1, 2048, 1024], 3, ttnn.TILE_LAYOUT), ([1, 1, 2048, 8192], 3, ttnn.TILE_LAYOUT), # These shapes result in some workers with no work, which is currently @@ -136,7 +136,7 @@ def test_reduce_scatter_t3k_4chip_nightly( pcie_mesh_device, num_devices, per_chip_output_shape, - scatter_dim, + dim, num_links, math_op, input_dtype, @@ -151,7 +151,7 @@ def test_reduce_scatter_t3k_4chip_nightly( pcie_mesh_device, num_devices, per_chip_output_shape, - scatter_dim, + dim, num_links, math_op, input_dtype, diff --git a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_post_commit.py b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_post_commit.py index 916682dd84e..4efe5152448 100644 --- a/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_post_commit.py +++ b/tests/ttnn/unit_tests/operations/ccl/test_reduce_scatter_post_commit.py @@ -10,7 +10,7 @@ from models.utility_functions import skip_for_grayskull -def is_unsupported_case(input_shape, scatter_dim, math_op, mem_config, num_devices, num_links, input_dtype, layout): +def is_unsupported_case(input_shape, dim, math_op, mem_config, num_devices, num_links, input_dtype, layout): elem_size = 2 if input_dtype == ttnn.bfloat16 else 1 tensor_size_bytes = elem_size for i in input_shape: @@ -19,7 +19,7 @@ def is_unsupported_case(input_shape, scatter_dim, math_op, mem_config, num_devic if mem_config.buffer_type == ttnn.BufferType.L1 and tensor_size_bytes > num_l1_banks * 50 * 1024: return True, "L1 buffer can't support large tensor sizes" - # if input_dtype == ttnn.bfloat8_b and tuple(input_shape) == (1, 1, 2048, 1024) and scatter_dim == 3: + # if input_dtype == ttnn.bfloat8_b and tuple(input_shape) == (1, 1, 2048, 1024) and dim == 3: # return True, "Known failure with bfp8_b data format" return False, "" @@ -28,7 +28,7 @@ def is_unsupported_case(input_shape, scatter_dim, math_op, mem_config, num_devic def run_with_trace( t3k_mesh_device, input_tensor_mesh, - scatter_dim, + dim, num_links, math_op, output_mem_config, @@ -41,7 +41,7 @@ def run_with_trace( logger.info("Compiling model") output_tensor_mesh = ttnn.reduce_scatter( input_tensor_mesh, - scatter_dim=scatter_dim, + dim=dim, math_op=math_op, num_links=num_links, memory_config=output_mem_config, @@ -58,7 +58,7 @@ def run_with_trace( for i in range(num_iters): output_tensor_mesh = ttnn.reduce_scatter( input_tensor_mesh, - scatter_dim=scatter_dim, + dim=dim, math_op=math_op, num_links=num_links, memory_config=output_mem_config, @@ -84,7 +84,7 @@ def run_reduce_scatter_test( mesh_device, num_devices, per_chip_output_shape, - scatter_dim, + dim, num_links, math_op, input_dtype, @@ -105,7 +105,7 @@ def run_reduce_scatter_test( debug = False (is_known_failure, message) = is_unsupported_case( - per_chip_output_shape, scatter_dim, math_op, mem_config, num_devices, num_links, input_dtype, layout + per_chip_output_shape, dim, math_op, mem_config, num_devices, num_links, input_dtype, layout ) if is_known_failure: pytest.skip(f"Skipping unsupported case {message}.") @@ -114,11 +114,11 @@ def run_reduce_scatter_test( if enable_async: logger.info(f"Using Async Mode for Reduce Scatter Op Dispatch") - logger.info(f"Per chip output shape: {per_chip_output_shape}, devices: {num_devices}, scatter_dim: {scatter_dim}") + logger.info(f"Per chip output shape: {per_chip_output_shape}, devices: {num_devices}, dim: {dim}") # Generate input tensors canonical_input_shape = per_chip_output_shape.copy() - canonical_input_shape[scatter_dim] *= num_devices + canonical_input_shape[dim] *= num_devices tt_input_tensors = [] numel = canonical_input_shape[0] * canonical_input_shape[1] * canonical_input_shape[2] * canonical_input_shape[3] @@ -143,7 +143,7 @@ def run_reduce_scatter_test( output_tensor_mesh = run_with_trace( mesh_device, input_tensor_mesh, - scatter_dim, + dim, num_links, math_op, mem_config, @@ -154,7 +154,7 @@ def run_reduce_scatter_test( for i in range(num_iters): output_tensor_mesh = ttnn.reduce_scatter( input_tensor_mesh, - scatter_dim=scatter_dim, + dim=dim, math_op=math_op, num_links=num_links, memory_config=mem_config, @@ -172,7 +172,7 @@ def run_reduce_scatter_test( for i, t in enumerate(input_tensors): golden_canonical_out_tensor = torch.add(golden_canonical_out_tensor, t).bfloat16() - golden_output_tensors = torch.chunk(golden_canonical_out_tensor, num_devices, scatter_dim) + golden_output_tensors = torch.chunk(golden_canonical_out_tensor, num_devices, dim) tt_out_tensors = ttnn.get_device_tensors(output_tensor_mesh) logger.info(f"Compare") @@ -211,7 +211,7 @@ def run_reduce_scatter_test( ], ) @pytest.mark.parametrize( - "per_chip_output_shape, scatter_dim, layout", + "per_chip_output_shape, dim, layout", [ ([1, 2, 256, 32 * 8], 3, ttnn.TILE_LAYOUT), # Input tensor is (16*32) x (64*32) = 8 * input tensor shape ([1, 1, 32, 32 * 8], 3, ttnn.TILE_LAYOUT), @@ -241,7 +241,7 @@ def test_ring_reduce_scatter_post_commit( t3k_mesh_device, num_devices, per_chip_output_shape, - scatter_dim, + dim, num_links, math_op, input_dtype, @@ -256,7 +256,7 @@ def test_ring_reduce_scatter_post_commit( t3k_mesh_device, num_devices, per_chip_output_shape, - scatter_dim, + dim, num_links, math_op, input_dtype, @@ -279,7 +279,7 @@ def test_ring_reduce_scatter_post_commit( ], ) @pytest.mark.parametrize( - "per_chip_output_shape, scatter_dim, layout", + "per_chip_output_shape, dim, layout", [ ([1, 1, 32, 32 * 8], 3, ttnn.TILE_LAYOUT), ([1, 2, 224, 32 * 8], 3, ttnn.TILE_LAYOUT), @@ -306,7 +306,7 @@ def test_line_reduce_scatter_post_commit( t3k_mesh_device, num_devices, per_chip_output_shape, - scatter_dim, + dim, num_links, math_op, input_dtype, @@ -321,7 +321,7 @@ def test_line_reduce_scatter_post_commit( t3k_mesh_device, num_devices, per_chip_output_shape, - scatter_dim, + dim, num_links, math_op, input_dtype, @@ -345,7 +345,7 @@ def test_line_reduce_scatter_post_commit( ], ) @pytest.mark.parametrize( - "per_chip_output_shape, scatter_dim, layout", + "per_chip_output_shape, dim, layout", [ ([1, 1, 32, 1280], 1, ttnn.TILE_LAYOUT), ([1, 1, 32, 1024], 1, ttnn.TILE_LAYOUT), @@ -369,7 +369,7 @@ def test_line_reduce_scatter_post_commit_4chip( pcie_mesh_device, num_devices, per_chip_output_shape, - scatter_dim, + dim, num_links, math_op, input_dtype, @@ -384,7 +384,7 @@ def test_line_reduce_scatter_post_commit_4chip( pcie_mesh_device, num_devices, per_chip_output_shape, - scatter_dim, + dim, num_links, math_op, input_dtype, @@ -403,7 +403,7 @@ def run_reduce_scatter_sharded_test( num_devices, per_chip_output_shape, output_shard_shape, - scatter_dim, + dim, num_links, math_op, shard_grid, @@ -427,7 +427,7 @@ def run_reduce_scatter_sharded_test( f"Not enough devices on machine to implement test case. Wanted {num_devices} but found {len(t3k_mesh_device.get_device_ids())}" ) - logger.info(f"Per chip output shape: {per_chip_output_shape}, devices: {num_devices}, scatter_dim: {scatter_dim}") + logger.info(f"Per chip output shape: {per_chip_output_shape}, devices: {num_devices}, dim: {dim}") debug = False @@ -438,7 +438,7 @@ def run_reduce_scatter_sharded_test( assert in_shard_override is None in_shard_grid = shard_grid input_shard_shape = list(output_shard_shape) - if scatter_dim == 3: + if dim == 3: input_shard_shape[1] *= num_devices else: input_shard_shape[0] *= num_devices @@ -468,7 +468,7 @@ def run_reduce_scatter_sharded_test( ) canonical_input_shape = list(per_chip_output_shape) - canonical_input_shape[scatter_dim] *= num_devices + canonical_input_shape[dim] *= num_devices numel = canonical_input_shape[0] * canonical_input_shape[1] * canonical_input_shape[2] * canonical_input_shape[3] input_tensors = [ @@ -492,7 +492,7 @@ def run_reduce_scatter_sharded_test( output_tensor_mesh = run_with_trace( t3k_mesh_device, input_tensor_mesh, - scatter_dim, + dim, num_links, math_op, output_mem_config, @@ -504,7 +504,7 @@ def run_reduce_scatter_sharded_test( for i in range(num_iters): output_tensor_mesh = ttnn.reduce_scatter( input_tensor_mesh, - scatter_dim=scatter_dim, + dim=dim, math_op=math_op, num_links=num_links, memory_config=output_mem_config, @@ -521,7 +521,7 @@ def run_reduce_scatter_sharded_test( for i, t in enumerate(input_tensors): golden_canonical_out_tensor = torch.add(golden_canonical_out_tensor, t).bfloat16() - golden_output_tensors = torch.chunk(golden_canonical_out_tensor, num_devices, scatter_dim) + golden_output_tensors = torch.chunk(golden_canonical_out_tensor, num_devices, dim) tt_out_tensors = ttnn.get_device_tensors(output_tensor_mesh) logger.info(f"Compare") diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather.cpp index 63983bd9f01..34f067df23d 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather.cpp @@ -9,7 +9,7 @@ namespace ttnn::operations::ccl { ttnn::Tensor ExecuteAllGather::invoke(const ttnn::Tensor& input_tensor, - const uint32_t dim, + const int32_t dim, const uint32_t num_links, const std::optional& memory_config, const std::optional num_workers, @@ -21,7 +21,7 @@ ttnn::Tensor ExecuteAllGather::invoke(const ttnn::Tensor& input_tensor, ttnn::Tensor ExecuteAllGather::invoke( const ttnn::Tensor& input_tensor, - const uint32_t dim, + const int32_t dim, const uint32_t cluster_axis, const MeshDevice& mesh_device, const uint32_t num_links, diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather.hpp index 1816d4c083d..541335982fa 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather.hpp @@ -14,7 +14,7 @@ namespace ccl { struct ExecuteAllGather { static ttnn::Tensor invoke( const ttnn::Tensor& input_tensor, - const uint32_t dim, + const int32_t dim, const uint32_t num_links = 1, const std::optional& memory_config = std::nullopt, const std::optional num_workers = std::nullopt, @@ -23,7 +23,7 @@ struct ExecuteAllGather { static ttnn::Tensor invoke( const ttnn::Tensor& input_tensor, - const uint32_t dim, + const int32_t dim, const uint32_t cluster_axis, const MeshDevice& mesh_device, const uint32_t num_links = 1, diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.cpp index 8937ced1230..19de7aa652b 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.cpp @@ -29,7 +29,7 @@ void bind_all_gather(pybind11::module& module, const ccl_operation_t& operation, ttnn::pybind_overload_t{ [](const ccl_operation_t& self, const ttnn::Tensor& input_tensor, - const uint32_t dim, + const int32_t dim, const uint32_t num_links, const std::optional& memory_config, const std::optional num_workers, @@ -49,7 +49,7 @@ void bind_all_gather(pybind11::module& module, const ccl_operation_t& operation, ttnn::pybind_overload_t{ [](const ccl_operation_t& self, const ttnn::Tensor& input_tensor, - const uint32_t dim, + const int32_t dim, const uint32_t cluster_axis, const MeshDevice& mesh_device, const uint32_t num_links, diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp index 32fc7afb01a..81b295dd3fc 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp @@ -175,7 +175,7 @@ namespace operations { namespace ccl { Tensor all_gather( - const Tensor& input_tensor, const uint32_t dim, const uint32_t num_links, const std::optional& memory_config, const std::optional user_defined_num_workers, const std::optional user_defined_num_buffers_per_channel, const ttnn::ccl::Topology topology) { + const Tensor& input_tensor, const int32_t dim, const uint32_t num_links, const std::optional& memory_config, const std::optional user_defined_num_workers, const std::optional user_defined_num_buffers_per_channel, const ttnn::ccl::Topology topology) { TT_FATAL(std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr, "all_gather op is only supported for Fast Dispatch"); auto devices = input_tensor.get_workers(); @@ -186,9 +186,16 @@ Tensor all_gather( if (num_devices == 2){ ccl_topology = ttnn::ccl::Topology::Linear; } + + int32_t rank = input_tensor.get_logical_shape().rank(); + + int32_t gather_dim = (dim < 0) ? rank + dim : dim; + + TT_FATAL(gather_dim >= -rank && gather_dim <= rank - 1 , "Dimension input should be in between -{} and {}, but has {}", rank, rank - 1, dim); + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))}; operation::launch_op( - [dim, num_links, memory_config, user_defined_num_workers, user_defined_num_buffers_per_channel, devices, ccl_topology]( + [gather_dim, num_links, memory_config, user_defined_num_workers, user_defined_num_buffers_per_channel, devices, ccl_topology]( const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { @@ -196,7 +203,7 @@ Tensor all_gather( const auto& input_tensor = input_tensors.at(0); return operation::run( - ttnn::ccl::all_gather_detail::create_all_gather_struct(input_tensor, dim, num_links, memory_config, user_defined_num_workers, user_defined_num_buffers_per_channel, devices, ccl_topology), + ttnn::ccl::all_gather_detail::create_all_gather_struct(input_tensor, gather_dim, num_links, memory_config, user_defined_num_workers, user_defined_num_buffers_per_channel, devices, ccl_topology), {input_tensor}); }, {input_tensor}, @@ -206,7 +213,7 @@ Tensor all_gather( Tensor all_gather( const Tensor& input_tensor, - const uint32_t dim, + const int32_t dim, const uint32_t cluster_axis, const MeshDevice& mesh_device, const uint32_t num_links, @@ -219,10 +226,16 @@ Tensor all_gather( const auto mesh_view = mesh_device.get_view(); std::size_t num_devices = (cluster_axis == 0) ? mesh_view->num_rows() : mesh_view->num_cols(); + int32_t rank = input_tensor.get_logical_shape().rank(); + + int32_t gather_dim = (dim < 0) ? rank + dim : dim; + + TT_FATAL(gather_dim >= -rank && gather_dim <= rank - 1 , "Dimension input should be in between -{} and {}, but has {}", rank, rank - 1, dim); + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))}; operation::launch_op( - [dim, num_links, memory_config, mesh_view, cluster_axis, user_defined_num_workers, user_defined_num_buffers_per_channel, num_devices, topology]( + [gather_dim, num_links, memory_config, mesh_view, cluster_axis, user_defined_num_workers, user_defined_num_buffers_per_channel, num_devices, topology]( const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& optional_output_tensors) mutable -> std::vector { @@ -250,7 +263,7 @@ Tensor all_gather( return operation::run( ttnn::AllGather{ - dim, num_links, num_devices, device_index, user_defined_num_workers, user_defined_num_buffers_per_channel, receiver_device_id, sender_device_id, memory_config.value_or(input_device_tensor.memory_config()), topology}, + gather_dim, num_links, num_devices, device_index, user_defined_num_workers, user_defined_num_buffers_per_channel, receiver_device_id, sender_device_id, memory_config.value_or(input_device_tensor.memory_config()), topology}, {input_device_tensor}); }, {input_tensor}, diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp index b0a162f2a1f..abc697dfab5 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp @@ -200,7 +200,7 @@ namespace ccl { Tensor all_gather( const Tensor& input_tensor, - const uint32_t dim, + const int32_t dim, const uint32_t num_links = 1, const std::optional& memory_config = std::nullopt, const std::optional user_defined_num_workers = std::nullopt, @@ -209,7 +209,7 @@ Tensor all_gather( Tensor all_gather( const Tensor& input_tensor, - const uint32_t dim, + const int32_t dim, const uint32_t cluster_axis, const MeshDevice& mesh_device, const uint32_t num_links = 1, diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp index 0924001d006..44ee7916127 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp @@ -107,7 +107,7 @@ namespace operations{ namespace ccl{ Tensor reduce_scatter( const Tensor& input_tensor, - const uint32_t scatter_dim, + const int32_t dim, ttnn::operations::reduction::ReduceType math_op, const uint32_t num_links, const MemoryConfig& output_mem_config, @@ -126,6 +126,12 @@ Tensor reduce_scatter( ccl_topology = ttnn::ccl::Topology::Linear; } + int16_t rank = input_tensor.get_logical_shape().rank(); + + int16_t scatter_dim = (dim < 0) ? rank + dim : dim; + + TT_FATAL(scatter_dim >= -rank && scatter_dim <= rank - 1 , "Dimension input should be in between -{} and {}, but has {}", rank, rank - 1, dim); + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))}; operation::launch_op( [binary_op_type, scatter_dim, num_links, output_mem_config, ccl_topology, devices, user_defined_num_workers, user_defined_num_buffers_per_channel]( @@ -158,7 +164,7 @@ Tensor reduce_scatter( Tensor reduce_scatter( const Tensor &input_tensor, - const uint32_t scatter_dim, + const int32_t dim, const uint32_t cluster_axis, const MeshDevice& mesh_device, ttnn::operations::reduction::ReduceType reduce_op, @@ -174,6 +180,12 @@ Tensor reduce_scatter( const auto mesh_view = mesh_device.get_view(); std::size_t num_devices = (cluster_axis == 0) ? mesh_view->num_rows() : mesh_view->num_cols(); + int16_t rank = input_tensor.get_logical_shape().rank(); + + int16_t scatter_dim = (dim < 0) ? rank + dim : dim; + + TT_FATAL(scatter_dim >= -rank && scatter_dim <= rank - 1 , "Dimension input should be in between -{} and {}, but has {}", rank, rank - 1, dim); + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))}; operation::launch_op( diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp index f26107cda30..57f5d055caa 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp @@ -69,7 +69,7 @@ namespace operations{ namespace ccl{ Tensor reduce_scatter( const Tensor &input_tensor, - const uint32_t scatter_split_dim, + const int32_t dim, ttnn::operations::reduction::ReduceType reduce_op = ttnn::operations::reduction::ReduceType::Sum, const uint32_t num_links = 1, const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, @@ -79,7 +79,7 @@ Tensor reduce_scatter( Tensor reduce_scatter( const ttnn::Tensor &input_tensor, - const uint32_t scatter_dim, + const int32_t dim, const uint32_t cluster_axis, const MeshDevice& mesh_device, ttnn::operations::reduction::ReduceType reduce_op = ttnn::operations::reduction::ReduceType::Sum, diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter.cpp index ea28f4bd932..027b159d8f8 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter.cpp @@ -10,7 +10,7 @@ namespace ttnn::operations::ccl { ttnn::Tensor ExecuteReduceScatter::invoke( const ttnn::Tensor& input_tensor, - const uint32_t scatter_dim, + const int32_t dim, ttnn::operations::reduction::ReduceType math_op, const uint32_t num_links, const std::optional& memory_config, @@ -19,11 +19,11 @@ ttnn::Tensor ExecuteReduceScatter::invoke( const std::optional num_buffers_per_channel) { MemoryConfig out_memory_config = memory_config.value_or(input_tensor.memory_config()); - return ttnn::operations::ccl::reduce_scatter(input_tensor, scatter_dim, math_op, num_links, out_memory_config, topology, num_workers, num_buffers_per_channel); + return ttnn::operations::ccl::reduce_scatter(input_tensor, dim, math_op, num_links, out_memory_config, topology, num_workers, num_buffers_per_channel); } ttnn::Tensor ExecuteReduceScatter::invoke( const ttnn::Tensor& input_tensor, - const uint32_t scatter_dim, + const int32_t dim, const uint32_t cluster_axis, const MeshDevice& mesh_device, ttnn::operations::reduction::ReduceType math_op, @@ -34,7 +34,7 @@ ttnn::Tensor ExecuteReduceScatter::invoke( const std::optional num_buffers_per_channel) { MemoryConfig out_memory_config = memory_config.value_or(input_tensor.memory_config()); - return ttnn::operations::ccl::reduce_scatter(input_tensor, scatter_dim, cluster_axis, mesh_device, math_op, num_links, out_memory_config, topology, num_workers, num_buffers_per_channel); + return ttnn::operations::ccl::reduce_scatter(input_tensor, dim, cluster_axis, mesh_device, math_op, num_links, out_memory_config, topology, num_workers, num_buffers_per_channel); } } // namespace ttnn::operations::ccl diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter.hpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter.hpp index b7acc80e794..044af18777c 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter.hpp @@ -17,7 +17,7 @@ namespace ccl { struct ExecuteReduceScatter { static ttnn::Tensor invoke( const Tensor &input_tensor, - const uint32_t scatter_dim, + const int32_t dim, const uint32_t cluster_axis, const MeshDevice& mesh_device, ttnn::operations::reduction::ReduceType reduce_op = ttnn::operations::reduction::ReduceType::Sum, @@ -29,7 +29,7 @@ struct ExecuteReduceScatter { static ttnn::Tensor invoke( const ttnn::Tensor& input_tensor, - const uint32_t scatter_dim, + const int32_t dim, ttnn::operations::reduction::ReduceType math_op, const uint32_t num_links = 1, const std::optional& memory_config = std::nullopt, diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter_pybind.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter_pybind.cpp index bfac2f9a1d1..011c217ff5a 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/reduce_scatter_pybind.cpp @@ -26,17 +26,17 @@ void bind_reduce_scatter(pybind11::module& module, const ccl_operation_t& operat ttnn::pybind_overload_t{ [](const ccl_operation_t& self, const ttnn::Tensor& input_tensor, - const uint32_t scatter_dim, + const int32_t dim, ttnn::operations::reduction::ReduceType math_op, const uint32_t num_links, const ttnn::MemoryConfig& memory_config, ttnn::ccl::Topology topology, const std::optional num_workers, const std::optional num_buffers_per_channel) -> ttnn::Tensor { - return self(input_tensor, scatter_dim, math_op, num_links, memory_config, topology, num_workers, num_buffers_per_channel); + return self(input_tensor, dim, math_op, num_links, memory_config, topology, num_workers, num_buffers_per_channel); }, py::arg("input_tensor"), - py::arg("scatter_dim"), + py::arg("dim"), py::arg("math_op"), py::kw_only(), py::arg("num_links") = 1, @@ -48,7 +48,7 @@ void bind_reduce_scatter(pybind11::module& module, const ccl_operation_t& operat ttnn::pybind_overload_t{ [](const ccl_operation_t& self, const ttnn::Tensor& input_tensor, - const uint32_t scatter_dim, + const int32_t dim, const uint32_t cluster_axis, const MeshDevice& mesh_device, ttnn::operations::reduction::ReduceType math_op, @@ -57,10 +57,10 @@ void bind_reduce_scatter(pybind11::module& module, const ccl_operation_t& operat const std::optional num_workers, const std::optional num_buffers_per_channel, const ttnn::ccl::Topology topology) -> ttnn::Tensor { - return self(input_tensor, scatter_dim, cluster_axis, mesh_device, math_op, num_links, output_mem_config, topology, num_workers, num_buffers_per_channel); + return self(input_tensor, dim, cluster_axis, mesh_device, math_op, num_links, output_mem_config, topology, num_workers, num_buffers_per_channel); }, py::arg("input_tensor"), - py::arg("scatter_dim"), + py::arg("dim"), py::arg("cluster_axis"), py::arg("mesh_device"), py::arg("math_op"),