diff --git a/docs/source/ttnn/ttnn/api.rst b/docs/source/ttnn/ttnn/api.rst index 92363033fbb..706ea8a7fe2 100644 --- a/docs/source/ttnn/ttnn/api.rst +++ b/docs/source/ttnn/ttnn/api.rst @@ -479,6 +479,7 @@ CCL ttnn.all_gather ttnn.reduce_scatter + ttnn.experimental.all_reduce Embedding diff --git a/tests/sweep_framework/sweeps/ccl/all_gather_n300.py b/tests/sweep_framework/sweeps/ccl/all_gather_n300.py index bd56a03a62b..61de4ba164c 100644 --- a/tests/sweep_framework/sweeps/ccl/all_gather_n300.py +++ b/tests/sweep_framework/sweeps/ccl/all_gather_n300.py @@ -8,12 +8,10 @@ import ttnn -from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from tests.ttnn.utils_for_testing import start_measuring_time, stop_measuring_time from loguru import logger -import ttnn from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc from tests.ttnn.unit_tests.operations.test_all_gather import is_unsupported_case_n300 -from ttnn import ShardTensorToMesh # Override the default timeout in seconds for hang detection. TIMEOUT = 30 diff --git a/tests/sweep_framework/sweeps/ccl/all_gather_n300_focused.py b/tests/sweep_framework/sweeps/ccl/all_gather_n300_focused.py index d57fe770d7a..710ccb9f21a 100644 --- a/tests/sweep_framework/sweeps/ccl/all_gather_n300_focused.py +++ b/tests/sweep_framework/sweeps/ccl/all_gather_n300_focused.py @@ -10,9 +10,7 @@ from tests.ttnn.utils_for_testing import start_measuring_time, stop_measuring_time from loguru import logger -import ttnn from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc -from tests.ttnn.unit_tests.operations.test_all_gather import is_unsupported_case_n300 # Override the default timeout in seconds for hang detection. 
TIMEOUT = 30 diff --git a/tests/sweep_framework/sweeps/ccl/line_all_gather.py b/tests/sweep_framework/sweeps/ccl/line_all_gather.py index 5973221e337..c99da0c0f35 100644 --- a/tests/sweep_framework/sweeps/ccl/line_all_gather.py +++ b/tests/sweep_framework/sweeps/ccl/line_all_gather.py @@ -8,9 +8,8 @@ import ttnn -from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from tests.ttnn.utils_for_testing import start_measuring_time, stop_measuring_time from loguru import logger -import ttnn from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc from tests.ttnn.unit_tests.operations.test_all_gather import is_unsupported_case from ttnn import ShardTensorToMesh diff --git a/tests/ttnn/unit_tests/operations/test_all_gather.py b/tests/ttnn/unit_tests/operations/test_all_gather.py index 233c2b6bfe1..38bcdaac987 100644 --- a/tests/ttnn/unit_tests/operations/test_all_gather.py +++ b/tests/ttnn/unit_tests/operations/test_all_gather.py @@ -8,8 +8,6 @@ import ttnn from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc from models.utility_functions import skip_for_grayskull -import itertools -from ttnn import ShardTensorToMesh def is_unsupported_case(input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout): diff --git a/tests/ttnn/unit_tests/operations/test_all_gather_N300_post_commit.py b/tests/ttnn/unit_tests/operations/test_all_gather_N300_post_commit.py index fbb80635469..9ec6de69203 100644 --- a/tests/ttnn/unit_tests/operations/test_all_gather_N300_post_commit.py +++ b/tests/ttnn/unit_tests/operations/test_all_gather_N300_post_commit.py @@ -2,11 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 -import torch import pytest -from loguru import logger import ttnn -from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc from models.utility_functions import skip_for_grayskull from 
tests.ttnn.unit_tests.operations.test_all_gather import ( run_all_gather_on_n300_impl, diff --git a/tests/ttnn/unit_tests/operations/test_all_gather_nightly.py b/tests/ttnn/unit_tests/operations/test_all_gather_nightly.py index fda60985e5a..cfdf32d6a76 100644 --- a/tests/ttnn/unit_tests/operations/test_all_gather_nightly.py +++ b/tests/ttnn/unit_tests/operations/test_all_gather_nightly.py @@ -7,7 +7,7 @@ from loguru import logger import ttnn from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc -from models.utility_functions import skip_for_grayskull, get_devices_for_t3000 +from models.utility_functions import skip_for_grayskull from tests.ttnn.unit_tests.operations.test_all_gather import ( is_unsupported_case, run_all_gather_on_t3000_impl, diff --git a/tests/ttnn/unit_tests/operations/test_all_reduce_t3000_frequent.py b/tests/ttnn/unit_tests/operations/test_all_reduce_t3000_frequent.py index 3e75367d05e..38e20e9d344 100644 --- a/tests/ttnn/unit_tests/operations/test_all_reduce_t3000_frequent.py +++ b/tests/ttnn/unit_tests/operations/test_all_reduce_t3000_frequent.py @@ -6,8 +6,8 @@ import pytest from loguru import logger import ttnn -from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc -from models.utility_functions import skip_for_grayskull, get_devices_for_t3000 +from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc +from models.utility_functions import skip_for_grayskull def is_unsupported_case(input_shape, math_op, mem_config, num_devices, num_links, input_dtype, layout): @@ -159,6 +159,7 @@ def run_all_reduce_test( assert not mismatch, f"{i} FAILED: {output}" +@skip_for_grayskull("Requires eth connected devices to run") @pytest.mark.timeout(120) @pytest.mark.parametrize( "num_devices, num_links", diff --git a/tests/ttnn/unit_tests/operations/test_reduce_scatter_N300_post_commit.py 
b/tests/ttnn/unit_tests/operations/test_reduce_scatter_N300_post_commit.py index b96530d7c1f..44604c19d25 100644 --- a/tests/ttnn/unit_tests/operations/test_reduce_scatter_N300_post_commit.py +++ b/tests/ttnn/unit_tests/operations/test_reduce_scatter_N300_post_commit.py @@ -2,9 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 -import torch import pytest -from loguru import logger import ttnn from models.utility_functions import skip_for_grayskull from tests.ttnn.unit_tests.operations.test_reduce_scatter_post_commit import ( diff --git a/tests/ttnn/unit_tests/operations/test_reduce_scatter_nightly.py b/tests/ttnn/unit_tests/operations/test_reduce_scatter_nightly.py index 8b886e1d6ab..26764654088 100644 --- a/tests/ttnn/unit_tests/operations/test_reduce_scatter_nightly.py +++ b/tests/ttnn/unit_tests/operations/test_reduce_scatter_nightly.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 -import torch import pytest import ttnn from tests.ttnn.unit_tests.operations.test_reduce_scatter_post_commit import ( diff --git a/ttnn/cpp/ttnn/operations/ccl/README.md b/ttnn/cpp/ttnn/operations/ccl/README.md index de1277a6ccf..7cc584c3414 100644 --- a/ttnn/cpp/ttnn/operations/ccl/README.md +++ b/ttnn/cpp/ttnn/operations/ccl/README.md @@ -12,6 +12,7 @@ all-reduce), allocation * All Gather * Reduce Scatter +* All Reduce (experimental) ### Configurations For the time being, input and output configurations are expected to match @@ -36,7 +37,7 @@ Future Tologies: * Torus (2d, 3d) # Future Operations -* All Reduce +* All Reduce (Full support) * Reduce * Scatter * Gather Scatter diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp index ab1993d03bb..e7a0fd0f9bd 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp @@ -32,20 +32,27 @@ AllGather create_all_gather_struct( for (uint32_t i = 0; i < 
num_devices; ++i) { if (devices[i] == input_tensor.device()) { device_index = i; - if (topology == ttnn::ccl::Topology::Ring) { - // Ring topology - receiver_device_id = devices[(i + 1) % num_devices]->id(); // Next device in the ring - sender_device_id = devices[(i + num_devices - 1) % num_devices]->id(); // Previous device in the ring - } else if (topology == ttnn::ccl::Topology::Linear) { - // Linear topology - bool is_last_chip_in_clockwise_direction = i == (num_devices - 1); - bool is_last_chip_in_counter_clockwise_direction = i == 0; - receiver_device_id = is_last_chip_in_clockwise_direction ? - std::nullopt : - std::optional(devices.at(i+1)->id()); - sender_device_id = is_last_chip_in_counter_clockwise_direction ? - std::nullopt : - std::optional(devices.at(i-1)->id()); + switch(topology){ + case ttnn::ccl::Topology::Ring:{ + // Ring topology + receiver_device_id = devices[(i + 1) % num_devices]->id(); // Next device in the ring + sender_device_id = devices[(i + num_devices - 1) % num_devices]->id(); // Previous device in the ring + break; + } + case ttnn::ccl::Topology::Linear:{ + // Linear topology + bool is_last_chip_in_clockwise_direction = i == (num_devices - 1); + bool is_last_chip_in_counter_clockwise_direction = i == 0; + receiver_device_id = is_last_chip_in_clockwise_direction ? + std::nullopt : + std::optional(devices.at(i+1)->id()); + sender_device_id = is_last_chip_in_counter_clockwise_direction ? 
+ std::nullopt : + std::optional(devices.at(i-1)->id()); + break; + } + default: + TT_FATAL(false, "Invalid Topology {}, Accepted topologies are Ring and Linear currently", topology); } break; } @@ -136,7 +143,7 @@ AllGatherConfig::AllGatherConfig(Tensor const& input_tensor, Tensor const& outpu void AllGather::validate(const std::vector &input_tensors) const { - TT_FATAL(input_tensors.size() == 1, "Error"); + TT_FATAL(input_tensors.size() == 1, "Error, Input tensor size should be 1 but has {}", input_tensors.size()); const auto& input_tensor = input_tensors[0]; const auto& layout = input_tensors[0].get_layout(); const auto& dtype = input_tensors[0].get_dtype(); @@ -148,9 +155,9 @@ void AllGather::validate(const std::vector &input_tensors) const { // TODO: Validate ring TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Operands to all_gather need to be on device!"); TT_FATAL(input_tensor.buffer() != nullptr , "Operands to all_gather need to be allocated in buffers on device!"); - TT_FATAL(this->num_links > 0, "Error"); + TT_FATAL(this->num_links > 0, "Error, num_links should be more than 0 but has {}", this->num_links); TT_FATAL(this->num_links <= input_tensor.device()->compute_with_storage_grid_size().y, "Worker cores used by links are parallelizaed over rows"); - TT_FATAL(this->receiver_device_id.has_value() || this->sender_device_id.has_value(), "Error"); + TT_FATAL(this->receiver_device_id.has_value() || this->sender_device_id.has_value(), "Error, All-gather was unable to identify either a sender or receiver device ID and at least one must be identified for a valid all-gather configuration.
The input mesh tensor or all-gather arguments may be incorrect"); TT_FATAL(input_tensor.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED || input_tensor.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED || @@ -196,7 +203,7 @@ namespace ccl { Tensor all_gather( const Tensor& input_tensor, const uint32_t dim, const uint32_t num_links, const std::optional& memory_config, const std::optional user_defined_num_workers, const std::optional user_defined_num_buffers_per_channel, const ttnn::ccl::Topology topology) { - TT_FATAL(std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr, "This op is only supported for Fast Dispatch"); + TT_FATAL(std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr, "all_gather op is only supported for Fast Dispatch"); auto devices = input_tensor.get_workers(); uint32_t num_devices = devices.size(); ttnn::ccl::Topology ccl_topology = topology; diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp index 6288ae6d0f9..2c87dd4dd00 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp @@ -75,7 +75,7 @@ Tensor reduce_scatter( const std::optional user_defined_num_workers, const std::optional user_defined_num_buffers_per_channel) { ttnn::operations::binary::BinaryOpType binary_op_type = convert_reduce_type_to_eltwise_type(math_op); - TT_FATAL(std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr, "This op is only supported for Fast Dispatch"); + TT_FATAL(std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr, "reduce_scatter op is only supported for Fast Dispatch"); auto devices = input_tensor.get_workers(); std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))}; @@ -110,7 +110,7 @@ Tensor reduce_scatter( break; } } - TT_FATAL(receiver_device_id != std::nullopt || 
sender_device_id != std::nullopt, "Error in reduce scatter op setup"); + TT_FATAL(receiver_device_id != std::nullopt || sender_device_id != std::nullopt, "Error, Reduce-scatter was unable to identify either a sender or receiver device ID and at least one must be identified for a valid Reduce-scatter configuration. The input mesh tensor or Reduce-scatter arguments may be incorrect"); return operation::run( ttnn::ReduceScatter{ diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce/device/all_reduce_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce/device/all_reduce_op.cpp index 2a04802bca0..ca82d3ba307 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce/device/all_reduce_op.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_reduce/device/all_reduce_op.cpp @@ -83,28 +83,8 @@ Tensor all_reduce( bool is_linear = topology == ttnn::ccl::Topology::Linear; const auto& input_tensor = input_tensors.at(0); - uint32_t num_devices = devices.size(); - uint32_t device_index = 0; // Initialize device index - std::optional receiver_device_id = std::nullopt; // Initialize receiver device ID - std::optional sender_device_id = std::nullopt; // Initialize sender device ID - for (uint32_t i = 0; i < num_devices; ++i) { - if (devices.at(i) == input_tensor.device()) { - - bool is_last_chip_in_clockwise_direction = is_linear && i == (num_devices - 1); - bool is_last_chip_in_counter_clockwise_direction = is_linear && i == 0; - device_index = i; - receiver_device_id = is_last_chip_in_clockwise_direction ? - std::nullopt : - std::optional(devices.at((i + 1) % num_devices)->id()); - sender_device_id = is_last_chip_in_counter_clockwise_direction ?
- std::nullopt : - std::optional(devices.at((i + num_devices - 1) % num_devices)->id()); - break; - } - } - TT_FATAL(receiver_device_id != std::nullopt || sender_device_id != std::nullopt, "Error in all reduce op setup"); - auto shape = input_tensor.shape(); + auto shape = input_tensor.get_logical_shape(); auto rank = shape.rank(); uint32_t merged_dim_size = 1;