
Commit

#0: CCL cleanup/fix (#14030)
* #0: CCL cleanup/fix

* #0: CCL Improve error messages

* #0: Report passed values
Aswinmcw authored Oct 22, 2024
1 parent 46e2f9b commit 204e3aa
Showing 14 changed files with 37 additions and 60 deletions.
1 change: 1 addition & 0 deletions docs/source/ttnn/ttnn/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,7 @@ CCL

ttnn.all_gather
ttnn.reduce_scatter
ttnn.experimental.all_reduce


Embedding
Expand Down
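For orientation, the three CCL ops documented above are invoked on a tensor that is already distributed across a device mesh. The sketch below is illustrative only: the argument names (`dim`, `num_links`, `math_op`) and the `ReduceType` spelling are assumptions based on the tests touched in this commit, not a verified API reference, and `mesh_tensor` is assumed to be a tensor already sharded across devices (e.g. via `ShardTensorToMesh`).

```python
import ttnn

# Hedged sketch: run the CCL ops listed in api.rst on an already-sharded mesh tensor.
# Argument names and enum spellings are assumptions and may differ between ttnn versions.
gathered = ttnn.all_gather(mesh_tensor, dim=3, num_links=1)
scattered = ttnn.reduce_scatter(mesh_tensor, dim=3, math_op=ttnn.ReduceType.Sum)
reduced = ttnn.experimental.all_reduce(mesh_tensor, math_op=ttnn.ReduceType.Sum)
```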
4 changes: 1 addition & 3 deletions tests/sweep_framework/sweeps/ccl/all_gather_n300.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,10 @@

import ttnn

from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
from tests.ttnn.utils_for_testing import start_measuring_time, stop_measuring_time
from loguru import logger
import ttnn
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc
from tests.ttnn.unit_tests.operations.test_all_gather import is_unsupported_case_n300
from ttnn import ShardTensorToMesh

# Override the default timeout in seconds for hang detection.
TIMEOUT = 30
Expand Down
2 changes: 0 additions & 2 deletions tests/sweep_framework/sweeps/ccl/all_gather_n300_focused.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@

from tests.ttnn.utils_for_testing import start_measuring_time, stop_measuring_time
from loguru import logger
import ttnn
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc
from tests.ttnn.unit_tests.operations.test_all_gather import is_unsupported_case_n300

# Override the default timeout in seconds for hang detection.
TIMEOUT = 30
Expand Down
3 changes: 1 addition & 2 deletions tests/sweep_framework/sweeps/ccl/line_all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@

import ttnn

from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time
from tests.ttnn.utils_for_testing import start_measuring_time, stop_measuring_time
from loguru import logger
import ttnn
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc
from tests.ttnn.unit_tests.operations.test_all_gather import is_unsupported_case
from ttnn import ShardTensorToMesh
Expand Down
2 changes: 0 additions & 2 deletions tests/ttnn/unit_tests/operations/test_all_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
import ttnn
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc
from models.utility_functions import skip_for_grayskull
import itertools
from ttnn import ShardTensorToMesh


def is_unsupported_case(input_shape, dim, mem_config, num_devices, num_links, input_dtype, layout):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@

# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
from loguru import logger
import ttnn
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc
from models.utility_functions import skip_for_grayskull
from tests.ttnn.unit_tests.operations.test_all_gather import (
run_all_gather_on_n300_impl,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from loguru import logger
import ttnn
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc
from models.utility_functions import skip_for_grayskull, get_devices_for_t3000
from models.utility_functions import skip_for_grayskull
from tests.ttnn.unit_tests.operations.test_all_gather import (
is_unsupported_case,
run_all_gather_on_t3000_impl,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import pytest
from loguru import logger
import ttnn
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_equal, comp_pcc
from models.utility_functions import skip_for_grayskull, get_devices_for_t3000
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc
from models.utility_functions import skip_for_grayskull


def is_unsupported_case(input_shape, math_op, mem_config, num_devices, num_links, input_dtype, layout):
Expand Down Expand Up @@ -159,6 +159,7 @@ def run_all_reduce_test(
assert not mismatch, f"{i} FAILED: {output}"


@skip_for_grayskull("Requires eth connected devices to run")
@pytest.mark.timeout(120)
@pytest.mark.parametrize(
"num_devices, num_links",
Expand Down
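As context for the all-reduce test above: after the op, every device should hold the elementwise reduction (here a sum) of all per-device inputs. A minimal host-side golden of that semantics, of the kind the comp_pcc/comp_equal comparisons would be checked against, could look like the sketch below; it is illustrative and not the test's actual golden code.

```python
import torch

def all_reduce_golden(per_device_inputs: list[torch.Tensor]) -> list[torch.Tensor]:
    # Every device ends up with the same elementwise sum over all device inputs.
    total = torch.zeros_like(per_device_inputs[0])
    for t in per_device_inputs:
        total = total + t
    return [total.clone() for _ in per_device_inputs]

# Example: 4 "devices", each contributing a 2x2 tensor filled with its index.
inputs = [torch.full((2, 2), float(i)) for i in range(4)]
outputs = all_reduce_golden(inputs)
assert all(torch.equal(o, torch.full((2, 2), 6.0)) for o in outputs)  # 0 + 1 + 2 + 3 = 6
```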
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@

# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
from loguru import logger
import ttnn
from models.utility_functions import skip_for_grayskull
from tests.ttnn.unit_tests.operations.test_reduce_scatter_post_commit import (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
import ttnn
from tests.ttnn.unit_tests.operations.test_reduce_scatter_post_commit import (
Expand Down
3 changes: 2 additions & 1 deletion ttnn/cpp/ttnn/operations/ccl/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ all-reduce), allocation

* All Gather
* Reduce Scatter
* All Reduce (experimental)

### Configurations
For the time being, input and output configurations are expected to match
Expand All @@ -36,7 +37,7 @@ Future Topologies:
* Torus (2d, 3d)

# Future Operations
* All Reduce
* All Reduce (Full support)
* Reduce
* Scatter
* Gather Scatter
Expand Down
43 changes: 25 additions & 18 deletions ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,27 @@ AllGather create_all_gather_struct(
for (uint32_t i = 0; i < num_devices; ++i) {
if (devices[i] == input_tensor.device()) {
device_index = i;
if (topology == ttnn::ccl::Topology::Ring) {
// Ring topology
receiver_device_id = devices[(i + 1) % num_devices]->id(); // Next device in the ring
sender_device_id = devices[(i + num_devices - 1) % num_devices]->id(); // Previous device in the ring
} else if (topology == ttnn::ccl::Topology::Linear) {
// Linear topology
bool is_last_chip_in_clockwise_direction = i == (num_devices - 1);
bool is_last_chip_in_counter_clockwise_direction = i == 0;
receiver_device_id = is_last_chip_in_clockwise_direction ?
std::nullopt :
std::optional<chip_id_t>(devices.at(i+1)->id());
sender_device_id = is_last_chip_in_counter_clockwise_direction ?
std::nullopt :
std::optional<chip_id_t>(devices.at(i-1)->id());
switch(topology){
case ttnn::ccl::Topology::Ring:{
// Ring topology
receiver_device_id = devices[(i + 1) % num_devices]->id(); // Next device in the ring
sender_device_id = devices[(i + num_devices - 1) % num_devices]->id(); // Previous device in the ring
break;
}
case ttnn::ccl::Topology::Linear:{
// Linear topology
bool is_last_chip_in_clockwise_direction = i == (num_devices - 1);
bool is_last_chip_in_counter_clockwise_direction = i == 0;
receiver_device_id = is_last_chip_in_clockwise_direction ?
std::nullopt :
std::optional<chip_id_t>(devices.at(i+1)->id());
sender_device_id = is_last_chip_in_counter_clockwise_direction ?
std::nullopt :
std::optional<chip_id_t>(devices.at(i-1)->id());
break;
}
default:
TT_FATAL(false, "Invalid topology {}. Accepted topologies are currently Ring and Linear", topology);
}
break;
}
Expand Down Expand Up @@ -136,7 +143,7 @@ AllGatherConfig::AllGatherConfig(Tensor const& input_tensor, Tensor const& outpu


void AllGather::validate(const std::vector<Tensor> &input_tensors) const {
TT_FATAL(input_tensors.size() == 1, "Error");
TT_FATAL(input_tensors.size() == 1, "Error, Input tensor size should be 1 but is {}", input_tensors.size());
const auto& input_tensor = input_tensors[0];
const auto& layout = input_tensors[0].get_layout();
const auto& dtype = input_tensors[0].get_dtype();
Expand All @@ -148,9 +155,9 @@ void AllGather::validate(const std::vector<Tensor> &input_tensors) const {
// TODO: Validate ring
TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Operands to all_gather need to be on device!");
TT_FATAL(input_tensor.buffer() != nullptr , "Operands to all_gather need to be allocated in buffers on device!");
TT_FATAL(this->num_links > 0, "Error");
TT_FATAL(this->num_links > 0, "Error, num_links should be more than 0 but is {}", this->num_links);
TT_FATAL(this->num_links <= input_tensor.device()->compute_with_storage_grid_size().y, "Worker cores used by links are parallelized over rows");
TT_FATAL(this->receiver_device_id.has_value() || this->sender_device_id.has_value(), "Error");
TT_FATAL(this->receiver_device_id.has_value() || this->sender_device_id.has_value(), "Error, All-gather was unable to identify either a sender or receiver device ID, and at least one must be identified for a valid all-gather configuration. The input mesh tensor or all-gather arguments may be incorrect");

TT_FATAL(input_tensor.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED ||
input_tensor.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED ||
Expand Down Expand Up @@ -196,7 +203,7 @@ namespace ccl {
Tensor all_gather(
const Tensor& input_tensor, const uint32_t dim, const uint32_t num_links, const std::optional<MemoryConfig>& memory_config, const std::optional<size_t> user_defined_num_workers, const std::optional<size_t> user_defined_num_buffers_per_channel, const ttnn::ccl::Topology topology) {

TT_FATAL(std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr, "This op is only supported for Fast Dispatch");
TT_FATAL(std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr, "all_gather op is only supported for Fast Dispatch");
auto devices = input_tensor.get_workers();
uint32_t num_devices = devices.size();
ttnn::ccl::Topology ccl_topology = topology;
Expand Down
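The switch added in create_all_gather_struct above picks each device's neighbours per topology: in a Ring the indices wrap around, while in a Linear chain the first device has no sender and the last has no receiver. A standalone Python sketch of the same index arithmetic (plain device indices stand in for the chip IDs used in the C++ code):

```python
from typing import Optional, Tuple

def ring_or_line_neighbors(i: int, num_devices: int, linear: bool) -> Tuple[Optional[int], Optional[int]]:
    """Return (receiver, sender) indices for device i, mirroring the C++ switch above."""
    if linear:
        receiver = i + 1 if i < num_devices - 1 else None  # last chip clockwise has no receiver
        sender = i - 1 if i > 0 else None                   # first chip counter-clockwise has no sender
    else:  # Ring: neighbours wrap around
        receiver = (i + 1) % num_devices
        sender = (i + num_devices - 1) % num_devices
    return receiver, sender

# Example on 4 devices: the ring wraps, the line truncates at its ends.
assert ring_or_line_neighbors(3, 4, linear=False) == (0, 2)
assert ring_or_line_neighbors(3, 4, linear=True) == (None, 2)
assert ring_or_line_neighbors(0, 4, linear=True) == (1, None)
```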
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ Tensor reduce_scatter(
const std::optional<size_t> user_defined_num_workers,
const std::optional<size_t> user_defined_num_buffers_per_channel) {
ttnn::operations::binary::BinaryOpType binary_op_type = convert_reduce_type_to_eltwise_type(math_op);
TT_FATAL(std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr, "This op is only supported for Fast Dispatch");
TT_FATAL(std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr, "reduce_scatter op is only supported for Fast Dispatch");

auto devices = input_tensor.get_workers();
std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))};
Expand Down Expand Up @@ -110,7 +110,7 @@ Tensor reduce_scatter(
break;
}
}
TT_FATAL(receiver_device_id != std::nullopt || sender_device_id != std::nullopt, "Error in reduce scatter op setup");
TT_FATAL(receiver_device_id != std::nullopt || sender_device_id != std::nullopt, "Error, Reduce-scatter was unable to identify either a sender or receiver device ID, and at least one must be identified for a valid reduce-scatter configuration. The input mesh tensor or reduce-scatter arguments may be incorrect");

return operation::run(
ttnn::ReduceScatter{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,28 +83,8 @@ Tensor all_reduce(
bool is_linear = topology == ttnn::ccl::Topology::Linear;

const auto& input_tensor = input_tensors.at(0);
uint32_t num_devices = devices.size();
uint32_t device_index = 0; // Initialize device index
std::optional<chip_id_t> receiver_device_id = std::nullopt; // Initialize receiver device ID
std::optional<chip_id_t> sender_device_id = std::nullopt; // Initialize sender device ID
for (uint32_t i = 0; i < num_devices; ++i) {
if (devices.at(i) == input_tensor.device()) {

bool is_last_chip_in_clockwise_direction = is_linear && i == (num_devices - 1);
bool is_last_chip_in_counter_clockwise_direction = is_linear && i == 0;
device_index = i;
receiver_device_id = is_last_chip_in_clockwise_direction ?
std::nullopt :
std::optional<chip_id_t>(devices.at((i + 1) % num_devices)->id());
sender_device_id = is_last_chip_in_counter_clockwise_direction ?
std::nullopt :
std::optional<chip_id_t>(devices.at((i + num_devices - 1) % num_devices)->id());
break;
}
}
TT_FATAL(receiver_device_id != std::nullopt || sender_device_id != std::nullopt, "Error in all reduce op setup");

auto shape = input_tensor.shape();
auto shape = input_tensor.get_logical_shape();
auto rank = shape.rank();

uint32_t merged_dim_size = 1;
Expand Down
