Add negative dim support
Aswinmcw committed Nov 21, 2024
1 parent a387b58 commit 4ca86ac
Showing 12 changed files with 73 additions and 48 deletions.
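The recurring change across these files is a single idiom: each CCL op now takes a signed dimension (int16_t instead of uint32_t), validates it against the tensor rank, and folds negative values into the equivalent non-negative index before launching work. A minimal self-contained sketch of that idiom follows; plain assert stands in for the repository's TT_FATAL macro, and normalize_dim is an illustrative name, not a function from the commit.

#include <cassert>
#include <cstdint>

// Fold a possibly-negative dimension index into [0, rank), mirroring the
// normalization this commit adds to all_gather and reduce_scatter.
int16_t normalize_dim(int16_t requested_dim, int16_t rank) {
    // Valid inputs span [-rank, rank - 1]; -1 means the last dimension.
    assert(requested_dim >= -rank && requested_dim <= rank - 1);
    return requested_dim < 0 ? static_cast<int16_t>(rank + requested_dim) : requested_dim;
}

// For a rank-4 tensor: normalize_dim(-1, 4) == 3 and normalize_dim(-4, 4) == 0,
// which is why the tests below swap dim 3 for -1 and dim 0 for -4.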
@@ -22,15 +22,15 @@
[
(4, 1, [4, 1, 33, 256], 0, ttnn.ROW_MAJOR_LAYOUT),
(8, 1, [8, 1, 33, 256], 0, ttnn.ROW_MAJOR_LAYOUT),
- (8, 1, [8, 1, 256, 32], 0, ttnn.TILE_LAYOUT),
+ (8, 1, [8, 1, 256, 32], -4, ttnn.TILE_LAYOUT),
(8, 1, [8, 8, 256, 384], 1, ttnn.ROW_MAJOR_LAYOUT),
# (4, 2, [8, 8, 256, 384], 1, ttnn.TILE_LAYOUT),
(8, 1, [8, 8, 256, 384], 1, ttnn.TILE_LAYOUT),
- (4, 1, [8, 5, 13, 384], 3, ttnn.ROW_MAJOR_LAYOUT),
- (8, 1, [8, 5, 13, 512], 3, ttnn.ROW_MAJOR_LAYOUT),
+ (4, 1, [8, 5, 13, 384], -1, ttnn.ROW_MAJOR_LAYOUT),
+ (8, 1, [8, 5, 13, 512], -1, ttnn.ROW_MAJOR_LAYOUT),
(4, 1, [8, 5, 32, 384], 3, ttnn.TILE_LAYOUT),
(8, 1, [8, 5, 32, 512], 3, ttnn.TILE_LAYOUT),
- (4, 1, [1, 1, 32, 16384], 3, ttnn.TILE_LAYOUT),
+ (4, 1, [1, 1, 32, 16384], -1, ttnn.TILE_LAYOUT),
],
)
@pytest.mark.parametrize(
@@ -24,18 +24,18 @@
([1, 8, 1024, 1024], 3, ttnn.TILE_LAYOUT),
([1, 4, 1024, 1024], 3, ttnn.TILE_LAYOUT),
([1, 4, 2048, 1024], 3, ttnn.TILE_LAYOUT),
- ([1, 1, 32, 32], 3, ttnn.TILE_LAYOUT),
- ([1, 1, 32, 64], 3, ttnn.TILE_LAYOUT),
+ ([1, 1, 32, 32], -1, ttnn.TILE_LAYOUT),
+ ([1, 1, 32, 64], -1, ttnn.TILE_LAYOUT),
([1, 1, 64, 64], 3, ttnn.TILE_LAYOUT),
- ([1, 1, 32, 128], 3, ttnn.TILE_LAYOUT),
+ ([1, 1, 32, 128], -1, ttnn.TILE_LAYOUT),
([1, 1, 32, 256], 3, ttnn.TILE_LAYOUT),
([1, 1, 32, 512], 3, ttnn.TILE_LAYOUT),
([1, 1, 32, 1024], 3, ttnn.TILE_LAYOUT),
([1, 1, 32, 2048], 3, ttnn.TILE_LAYOUT),
- ([1, 1, 128, 1024], 3, ttnn.TILE_LAYOUT),
+ ([1, 1, 128, 1024], -1, ttnn.TILE_LAYOUT),
([1, 1, 128, 8192], 3, ttnn.TILE_LAYOUT),
([1, 1, 2048, 1024], 3, ttnn.TILE_LAYOUT),
- ([1, 1, 2048, 8192], 3, ttnn.TILE_LAYOUT),
+ ([1, 1, 2048, 8192], -1, ttnn.TILE_LAYOUT),
],
)
@pytest.mark.parametrize(
@@ -99,12 +99,12 @@ def test_reduce_scatter_t3k_8chip_nightly(
[
([1, 8, 1024, 1024], 3, ttnn.TILE_LAYOUT),
([1, 4, 1024, 1024], 3, ttnn.TILE_LAYOUT),
- ([1, 4, 2048, 1024], 3, ttnn.TILE_LAYOUT),
+ ([1, 4, 2048, 1024], -1, ttnn.TILE_LAYOUT),
([1, 1, 32, 512], 3, ttnn.TILE_LAYOUT),
([1, 1, 32, 1024], 3, ttnn.TILE_LAYOUT),
- ([1, 1, 32, 2048], 3, ttnn.TILE_LAYOUT),
+ ([1, 1, 32, 2048], -1, ttnn.TILE_LAYOUT),
([1, 1, 128, 1024], 3, ttnn.TILE_LAYOUT),
- ([1, 1, 128, 8192], 3, ttnn.TILE_LAYOUT),
+ ([1, 1, 128, 8192], -1, ttnn.TILE_LAYOUT),
([1, 1, 2048, 1024], 3, ttnn.TILE_LAYOUT),
([1, 1, 2048, 8192], 3, ttnn.TILE_LAYOUT),
# These shapes result in some workers with no work, which is currently
8 changes: 4 additions & 4 deletions ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather.cpp
@@ -9,19 +9,19 @@
namespace ttnn::operations::ccl {

ttnn::Tensor ExecuteAllGather::invoke(const ttnn::Tensor& input_tensor,
- const uint32_t dim,
+ const int16_t gather_dim,
const uint32_t num_links,
const std::optional<ttnn::MemoryConfig>& memory_config,
const std::optional<size_t> num_workers,
const std::optional<size_t> num_buffers_per_channel,
const ttnn::ccl::Topology topology) {
return ttnn::operations::ccl::all_gather(
- input_tensor, dim, num_links, memory_config, num_workers, num_buffers_per_channel, topology);
+ input_tensor, gather_dim, num_links, memory_config, num_workers, num_buffers_per_channel, topology);
}

ttnn::Tensor ExecuteAllGather::invoke(
const ttnn::Tensor& input_tensor,
- const uint32_t dim,
+ const int16_t gather_dim,
const uint32_t cluster_axis,
const MeshDevice& mesh_device,
const uint32_t num_links,
@@ -30,7 +30,7 @@ ttnn::Tensor ExecuteAllGather::invoke(
const std::optional<size_t> num_buffers_per_channel,
const ttnn::ccl::Topology topology) {
return ttnn::operations::ccl::all_gather(
- input_tensor, dim, cluster_axis, mesh_device, num_links, memory_config, num_workers, num_buffers_per_channel, topology);
+ input_tensor, gather_dim, cluster_axis, mesh_device, num_links, memory_config, num_workers, num_buffers_per_channel, topology);
}

} // namespace ttnn::operations::ccl
4 changes: 2 additions & 2 deletions ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather.hpp
@@ -14,7 +14,7 @@ namespace ccl {
struct ExecuteAllGather {
static ttnn::Tensor invoke(
const ttnn::Tensor& input_tensor,
- const uint32_t dim,
+ const int16_t gather_dim,
const uint32_t num_links = 1,
const std::optional<ttnn::MemoryConfig>& memory_config = std::nullopt,
const std::optional<size_t> num_workers = std::nullopt,
@@ -23,7 +23,7 @@ struct ExecuteAllGather {

static ttnn::Tensor invoke(
const ttnn::Tensor& input_tensor,
- const uint32_t dim,
+ const int16_t gather_dim,
const uint32_t cluster_axis,
const MeshDevice& mesh_device,
const uint32_t num_links = 1,
16 changes: 8 additions & 8 deletions ttnn/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.cpp
@@ -29,16 +29,16 @@ void bind_all_gather(pybind11::module& module, const ccl_operation_t& operation,
ttnn::pybind_overload_t{
[](const ccl_operation_t& self,
const ttnn::Tensor& input_tensor,
- const uint32_t dim,
+ const int16_t gather_dim,
const uint32_t num_links,
const std::optional<ttnn::MemoryConfig>& memory_config,
const std::optional<size_t> num_workers,
const std::optional<size_t> num_buffers_per_channel,
const ttnn::ccl::Topology topology) -> ttnn::Tensor {
- return self(input_tensor, dim, num_links, memory_config, num_workers, num_buffers_per_channel, topology);
+ return self(input_tensor, gather_dim, num_links, memory_config, num_workers, num_buffers_per_channel, topology);
},
py::arg("input_tensor"),
py::arg("dim"),
py::arg("gather_dim"),
py::kw_only(),
py::arg("num_links") = 1,
py::arg("memory_config") = std::nullopt,
@@ -49,18 +49,18 @@
ttnn::pybind_overload_t{
[](const ccl_operation_t& self,
const ttnn::Tensor& input_tensor,
- const uint32_t dim,
+ const int16_t gather_dim,
const uint32_t cluster_axis,
const MeshDevice& mesh_device,
const uint32_t num_links,
const std::optional<ttnn::MemoryConfig>& memory_config,
const std::optional<size_t> num_workers,
const std::optional<size_t> num_buffers_per_channel,
const ttnn::ccl::Topology topology) -> ttnn::Tensor {
- return self(input_tensor, dim, cluster_axis, mesh_device, num_links, memory_config, num_workers, num_buffers_per_channel, topology);
+ return self(input_tensor, gather_dim, cluster_axis, mesh_device, num_links, memory_config, num_workers, num_buffers_per_channel, topology);
},
py::arg("input_tensor"),
py::arg("dim"),
py::arg("gather_dim"),
py::arg("cluster_axis"),
py::arg("mesh_device"),
py::kw_only(),
@@ -84,7 +84,7 @@ void py_bind_all_gather(pybind11::module& module) {
Args:
input_tensor (ttnn.Tensor): multi-device tensor.
- dim (int): Dimension to perform operation.
+ gather_dim (int): Dimension along which to gather; negative values count back from the last dimension.
cluster_axis (int): Provided a MeshTensor, the axis corresponding to MeshDevice to perform the line-all-gather operation on.
mesh_device (MeshDevice): Device mesh to perform the line-all-gather operation on.
* cluster_axis and mesh_device parameters are applicable only for Linear Topology.
@@ -113,7 +113,7 @@
memory_config=mem_config,
mesh_mapper=ShardTensor2dMesh(mesh_device, mesh_shape=(1, 8), dims=(-1, -2)))
>>> ttnn_tensor = ttnn.to_device(ttnn_tensor, mesh_device)
- >>> output = ttnn.all_gather(ttnn_tensor, dim=0, topology=ttnn.Topology.Ring)
+ >>> output = ttnn.all_gather(ttnn_tensor, gather_dim=0, topology=ttnn.Topology.Ring)
)doc");
}
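For context, a hypothetical C++ call site exercising the new signed parameter might look like the sketch below. The include path and the direct call to ExecuteAllGather::invoke are assumptions based on the headers in this diff, not code from the commit; num_links and the remaining arguments fall back to their declared defaults.

#include "ttnn/operations/ccl/all_gather/all_gather.hpp"

// Gather along the last dimension of a multi-device tensor; -1 is folded to
// rank - 1 by the normalization added in all_gather_op.cpp below.
ttnn::Tensor gather_last_dim(const ttnn::Tensor& input_tensor) {
    return ttnn::operations::ccl::ExecuteAllGather::invoke(input_tensor, /*gather_dim=*/-1);
}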
17 changes: 15 additions & 2 deletions ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp
@@ -174,7 +174,7 @@ namespace operations {
namespace ccl {

Tensor all_gather(
- const Tensor& input_tensor, const uint32_t dim, const uint32_t num_links, const std::optional<MemoryConfig>& memory_config, const std::optional<size_t> user_defined_num_workers, const std::optional<size_t> user_defined_num_buffers_per_channel, const ttnn::ccl::Topology topology) {
+ const Tensor& input_tensor, const int16_t gather_dim, const uint32_t num_links, const std::optional<MemoryConfig>& memory_config, const std::optional<size_t> user_defined_num_workers, const std::optional<size_t> user_defined_num_buffers_per_channel, const ttnn::ccl::Topology topology) {

TT_FATAL(std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr, "all_gather op is only supported for Fast Dispatch");
auto devices = input_tensor.get_workers();
@@ -185,6 +185,13 @@ Tensor all_gather(
if (num_devices == 2){
ccl_topology = ttnn::ccl::Topology::Linear;
}

+ int16_t rank = input_tensor.get_logical_shape().rank();
+
+ TT_FATAL(gather_dim >= -rank && gather_dim <= rank - 1, "Dimension must be between -{} and {}, but was {}", rank, rank - 1, gather_dim);
+
+ int16_t dim = (gather_dim < 0) ? rank + gather_dim : gather_dim;

std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))};
operation::launch_op(
[dim, num_links, memory_config, user_defined_num_workers, user_defined_num_buffers_per_channel, devices, ccl_topology](
@@ -205,7 +212,7 @@

Tensor all_gather(
const Tensor& input_tensor,
- const uint32_t dim,
+ const int16_t gather_dim,
const uint32_t cluster_axis,
const MeshDevice& mesh_device,
const uint32_t num_links,
@@ -218,6 +225,12 @@ Tensor all_gather(
const auto mesh_view = mesh_device.get_view();
std::size_t num_devices = (cluster_axis == 0) ? mesh_view->num_rows() : mesh_view->num_cols();

+ int16_t rank = input_tensor.get_logical_shape().rank();
+
+ TT_FATAL(gather_dim >= -rank && gather_dim <= rank - 1, "Dimension must be between -{} and {}, but was {}", rank, rank - 1, gather_dim);
+
+ int16_t dim = (gather_dim < 0) ? rank + gather_dim : gather_dim;

std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))};

operation::launch_op(
@@ -200,7 +200,7 @@ namespace ccl {

Tensor all_gather(
const Tensor& input_tensor,
- const uint32_t dim,
+ const int16_t gather_dim,
const uint32_t num_links = 1,
const std::optional<MemoryConfig>& memory_config = std::nullopt,
const std::optional<size_t> user_defined_num_workers = std::nullopt,
@@ -209,7 +209,7 @@

Tensor all_gather(
const Tensor& input_tensor,
- const uint32_t dim,
+ const int16_t gather_dim,
const uint32_t cluster_axis,
const MeshDevice& mesh_device,
const uint32_t num_links = 1,
@@ -107,7 +107,7 @@ namespace operations{
namespace ccl{
Tensor reduce_scatter(
const Tensor& input_tensor,
- const uint32_t scatter_dim,
+ const int16_t scatter_dim,
ttnn::operations::reduction::ReduceType math_op,
const uint32_t num_links,
const MemoryConfig& output_mem_config,
@@ -126,9 +126,15 @@ Tensor reduce_scatter(
ccl_topology = ttnn::ccl::Topology::Linear;
}

+ int16_t rank = input_tensor.get_logical_shape().rank();
+
+ TT_FATAL(scatter_dim >= -rank && scatter_dim <= rank - 1, "Dimension must be between -{} and {}, but was {}", rank, rank - 1, scatter_dim);
+
+ int16_t dim = (scatter_dim < 0) ? rank + scatter_dim : scatter_dim;

std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))};
operation::launch_op(
- [binary_op_type, scatter_dim, num_links, output_mem_config, ccl_topology, devices, user_defined_num_workers, user_defined_num_buffers_per_channel](
+ [binary_op_type, dim, num_links, output_mem_config, ccl_topology, devices, user_defined_num_workers, user_defined_num_buffers_per_channel](
const std::vector<Tensor>& input_tensors,
const std::vector<std::optional<const Tensor>>& optional_input_tensors,
const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
@@ -139,7 +145,7 @@ Tensor reduce_scatter(
ttnn::ccl::reduce_scatter_detail::create_reduce_scatter_struct(
input_tensor,
binary_op_type,
- scatter_dim,
+ dim,
num_links,
output_mem_config,
user_defined_num_workers,
@@ -158,7 +164,7 @@

Tensor reduce_scatter(
const Tensor &input_tensor,
- const uint32_t scatter_dim,
+ const int16_t scatter_dim,
const uint32_t cluster_axis,
const MeshDevice& mesh_device,
ttnn::operations::reduction::ReduceType reduce_op,
@@ -174,10 +180,16 @@
const auto mesh_view = mesh_device.get_view();
std::size_t num_devices = (cluster_axis == 0) ? mesh_view->num_rows() : mesh_view->num_cols();

+ int16_t rank = input_tensor.get_logical_shape().rank();
+
+ TT_FATAL(scatter_dim >= -rank && scatter_dim <= rank - 1, "Dimension must be between -{} and {}, but was {}", rank, rank - 1, scatter_dim);
+
+ int16_t dim = (scatter_dim < 0) ? rank + scatter_dim : scatter_dim;

std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))};

operation::launch_op(
- [scatter_dim, binary_op_type, num_links, output_mem_config, mesh_view, cluster_axis, user_defined_num_workers, user_defined_num_buffers_per_channel, num_devices, topology](
+ [dim, binary_op_type, num_links, output_mem_config, mesh_view, cluster_axis, user_defined_num_workers, user_defined_num_buffers_per_channel, num_devices, topology](
const std::vector<Tensor>& input_tensors,
const std::vector<std::optional<const Tensor>>& optional_input_tensors,
const std::vector<std::optional<Tensor>>& optional_output_tensors) mutable -> std::vector<Tensor> {
@@ -206,7 +218,7 @@ Tensor reduce_scatter(
return operation::run(
ttnn::ReduceScatter{
binary_op_type,
- scatter_dim,
+ dim,
num_links,
num_devices,
device_index,
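One detail worth noting in the reduce_scatter changes above: operation::launch_op defers execution, so each overload normalizes the dimension eagerly and the lambda's capture list switches from the raw scatter_dim to the normalized dim. A standalone sketch of that capture-by-value pattern, with illustrative names that are not from the repository:

#include <cstdint>
#include <functional>

// Normalize up front so the deferred work captures a value that is already
// validated; capturing the raw input would push normalization into the lambda.
std::function<int16_t()> make_deferred_dim(int16_t scatter_dim, int16_t rank) {
    const int16_t dim =
        scatter_dim < 0 ? static_cast<int16_t>(rank + scatter_dim) : scatter_dim;
    return [dim] { return dim; };  // captures the normalized dim by copy
}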
@@ -69,7 +69,7 @@ namespace operations{
namespace ccl{
Tensor reduce_scatter(
const Tensor &input_tensor,
- const uint32_t scatter_split_dim,
+ const int16_t scatter_split_dim,
ttnn::operations::reduction::ReduceType reduce_op = ttnn::operations::reduction::ReduceType::Sum,
const uint32_t num_links = 1,
const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG,
@@ -79,7 +79,7 @@

Tensor reduce_scatter(
const ttnn::Tensor &input_tensor,
- const uint32_t scatter_dim,
+ const int16_t scatter_dim,
const uint32_t cluster_axis,
const MeshDevice& mesh_device,
ttnn::operations::reduction::ReduceType reduce_op = ttnn::operations::reduction::ReduceType::Sum,
@@ -10,7 +10,7 @@ namespace ttnn::operations::ccl {

ttnn::Tensor ExecuteReduceScatter::invoke(
const ttnn::Tensor& input_tensor,
- const uint32_t scatter_dim,
+ const int16_t scatter_dim,
ttnn::operations::reduction::ReduceType math_op,
const uint32_t num_links,
const std::optional<ttnn::MemoryConfig>& memory_config,
@@ -23,7 +23,7 @@ ttnn::Tensor ExecuteReduceScatter::invoke(
}
ttnn::Tensor ExecuteReduceScatter::invoke(
const ttnn::Tensor& input_tensor,
- const uint32_t scatter_dim,
+ const int16_t scatter_dim,
const uint32_t cluster_axis,
const MeshDevice& mesh_device,
ttnn::operations::reduction::ReduceType math_op,
@@ -17,7 +17,7 @@ namespace ccl {
struct ExecuteReduceScatter {
static ttnn::Tensor invoke(
const Tensor &input_tensor,
- const uint32_t scatter_dim,
+ const int16_t scatter_dim,
const uint32_t cluster_axis,
const MeshDevice& mesh_device,
ttnn::operations::reduction::ReduceType reduce_op = ttnn::operations::reduction::ReduceType::Sum,
@@ -29,7 +29,7 @@ struct ExecuteReduceScatter {

static ttnn::Tensor invoke(
const ttnn::Tensor& input_tensor,
- const uint32_t scatter_dim,
+ const int16_t scatter_dim,
ttnn::operations::reduction::ReduceType math_op,
const uint32_t num_links = 1,
const std::optional<ttnn::MemoryConfig>& memory_config = std::nullopt,