Skip to content

Commit

Permalink
Fused AllGather+Matmul (#11760)
Browse files Browse the repository at this point in the history
#10415: Adding support for overlapped all_gather and matmul

This PR merges in the fully functional fused op ttnn.experimental.all_gather_matmul for t3k, that overlaps communication with computation to increase efficiency.

Currently, the supported matmul program configs are MatmulMultiCoreReuseMultiCastProgramConfig and MatmulMultiCoreReuseMultiCast1DProgramConfig. For matmul 2D, interleaved tensors are supported. For matmul 1D, interleaved tensors with mcast_in0=True are supported, and sharded in0 is supported as long as the number of shards equals the number of devices (8 on T3K).

For this op, the all_gather and matmul kernels are changed. However, if the ttnn.experimental.all_gather_matmul (fused) op is not called, then separate calls to all_gather and matmul will operate as normal, with fusion turned off.
  • Loading branch information
avoraTT authored Aug 31, 2024
1 parent 2e14e61 commit 5a4fc17
Show file tree
Hide file tree
Showing 23 changed files with 1,606 additions and 382 deletions.
509 changes: 413 additions & 96 deletions tests/ttnn/unit_tests/operations/test_all_gather_matmul.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions ttnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ set(ALL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/all_gather_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/all_gather_matmul_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/multi_core/all_gather_matmul_op_multi_core.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/ccl_op_fusion.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/ccl_common.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/ccl_host_datastructures.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/ccl/line_all_gather/device/line_all_gather_op.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#include "tt_metal/host_api.hpp"
#include "ttnn/operations/ccl/ccl_host_datastructures.hpp"
#include "ttnn/operations/ccl/ccl_common.hpp"
#include "ttnn/operations/experimental/ccl/ccl_op_fusion.hpp"
#include "ttnn/operations/ccl/ccl_op_fusion.hpp"


#include "ttnn/run_operation.hpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,7 @@ void kernel_main() {
OpSignaler op_signaler;

if constexpr(fuse_op) {
op_signaler = OpSignaler(
get_compile_time_arg_val(25),
get_compile_time_arg_val(26),
get_compile_time_arg_val(27),
arg_idx
);
op_signaler = OpSignaler(arg_idx);
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include "ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_ring_gather_utils.hpp"
#include "ttnn/cpp/ttnn/operations/ccl/shared_with_host/sharded_tensor_addr_gen.hpp"
#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_edm_adapters.hpp"
#include "ttnn/cpp/ttnn/operations/ccl/kernel_common/worker_sync_utils.hpp"


void kernel_main() {
uint32_t arg_idx = 0;
Expand Down Expand Up @@ -45,18 +47,26 @@ void kernel_main() {
constexpr uint32_t eth_sender_noc_y = get_compile_time_arg_val(19);
constexpr uint32_t half_cb_n_pages = get_compile_time_arg_val(20);
constexpr uint32_t num_buffers_per_channel = get_compile_time_arg_val(21);
constexpr bool fuse_op = get_compile_time_arg_val(22);

/* Args for overlapped all gather */
OpSignaler op_signaler;

if constexpr(fuse_op) {
op_signaler = OpSignaler(arg_idx);
}

static_assert(half_cb_n_pages > rem_num_pages, "half_cb_n_pages must be greater than or equal to rem_num_pages");

#ifdef SHARDED_MEM_LAYOUT
constexpr tt::tt_metal::TensorMemoryLayout output_tensor_memory_layout = static_cast<tt::tt_metal::TensorMemoryLayout>(get_compile_time_arg_val(22));
constexpr uint32_t output_tensor_shard_grid_height = get_compile_time_arg_val(23);
constexpr uint32_t output_tensor_shard_grid_width = get_compile_time_arg_val(24);
constexpr uint32_t output_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(25);
constexpr uint32_t output_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(26);
constexpr uint32_t output_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(27);
constexpr uint32_t output_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(28);
constexpr bool output_tensor_shard_grid_transposed = get_compile_time_arg_val(29) != 0;
constexpr tt::tt_metal::TensorMemoryLayout output_tensor_memory_layout = static_cast<tt::tt_metal::TensorMemoryLayout>(get_compile_time_arg_val(23));
constexpr uint32_t output_tensor_shard_grid_height = get_compile_time_arg_val(24);
constexpr uint32_t output_tensor_shard_grid_width = get_compile_time_arg_val(25);
constexpr uint32_t output_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(26);
constexpr uint32_t output_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(27);
constexpr uint32_t output_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(28);
constexpr uint32_t output_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(29);
constexpr bool output_tensor_shard_grid_transposed = get_compile_time_arg_val(30) != 0;
#endif

constexpr uint32_t cb_id_in0 = tt::CB::c_in0;
Expand Down Expand Up @@ -138,6 +148,11 @@ void kernel_main() {
pop_filler_pages_from_cb(cb_id_in0, half_cb_n_pages - rem_num_pages);
}

if constexpr(fuse_op) {
// Synchronize and signal that the local tensor slice is available
op_signaler.synchronize_workers_and_signal_op();
}

// num_transfers = num_devices - 1
for (uint32_t i = 1; i < num_transfers; ++i) {
if constexpr(num_full_chunks > 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
#include <sstream>
#include <type_traits>

#include "ttnn/operations/experimental/ccl/ccl_op_fusion.hpp"
#include "ttnn/operations/ccl/ccl_op_fusion.hpp"


using namespace tt::constants;
Expand Down Expand Up @@ -207,8 +207,11 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(

/* All gather fusion */
bool fuse_op = fused_op_signaler.has_value();

// Need a separate signaler for the sender workers, to handle the first tensor slice that is locally available
std::optional<experimental::ccl::AllGatherFusedOpSignaler> fused_op_signaler_sender_workers;
if (fuse_op) {
fused_op_signaler->init_fused_op(device);
fused_op_signaler_sender_workers = fused_op_signaler.value();
}
auto const& all_gather_config = AllGatherConfig(input_tensor, output_tensor, dim, ring_size, num_links, topology, num_edm_buffers_per_channel, fuse_op);
auto const& topology_config = ttnn::ccl::RingTopology(device, topology, sender_device_id, receiver_device_id, num_links, ring_size, ring_index);
Expand Down Expand Up @@ -263,6 +266,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
constexpr uint32_t max_num_full_send_directions = 2;
// number of worker cores is 2x this since there is 1 worker for the sender buffer and 1 worker for the receiver buffer
uint32_t global_num_workers = num_links * all_gather_config.get_num_eth_buffers_per_edm() * num_full_send_directions;
uint32_t global_num_workers_per_direction = global_num_workers / num_full_send_directions;
uint32_t total_worker_core_pairs_used = global_num_workers;

uint32_t num_input_pages = input_tensor.buffer()->size() / input_page_size;
Expand Down Expand Up @@ -478,6 +482,9 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
/* All gather fusion */
if (fuse_op) {
fused_op_signaler->init_all_gather(program, device, receiver_workers, receiver_worker_cores);
if (direction == 1) {
fused_op_signaler_sender_workers->init_all_gather(program, device, sender_workers, sender_worker_cores);
}
}

{
Expand Down Expand Up @@ -726,12 +733,15 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
static_cast<uint32_t>(device->ethernet_core_from_logical_core(worker_eth_sender_core).x),
static_cast<uint32_t>(device->ethernet_core_from_logical_core(worker_eth_sender_core).y),
static_cast<uint32_t>(cb_num_pages / 2),
static_cast<uint32_t>(num_edm_buffers_per_channel)
static_cast<uint32_t>(num_edm_buffers_per_channel),

static_cast<bool>(fuse_op && direction == 1)
};

if (is_sharded) {
emit_sharded_tensor_kernel_ct_args(device, output_tensor, worker_writer_sender_ct_args, output_pages_per_shard_y, output_pages_per_shard_x);
}

log_trace(tt::LogOp, "Worker {} SW CT args", b);
log_trace(tt::LogOp, "\tall_gather_config.is_output_dram(): {}", all_gather_config.is_output_dram());
log_trace(tt::LogOp, "\tsender_num_transfers: {}", sender_worker_num_transfers.at(i).at(b));
Expand Down Expand Up @@ -774,6 +784,19 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
emit_sharded_tensor_kernel_rt_args(device, output_tensor, worker_writer_sender_rt_args);
}

if (fuse_op && direction == 1) {
fused_op_signaler_sender_workers->push_all_gather_fused_op_rt_args(
worker_writer_sender_rt_args,
global_num_workers_per_direction,
b,
is_clockwise_direction ? 0 : 1,
std::make_optional<experimental::ccl::CoreSemPair>(
{fused_op_signaler->all_gather_worker_cores_noc[0],
fused_op_signaler->all_gather_worker_sync_semaphore}
)
);
}

log_trace(tt::LogOp, "Worker {} SW rt args", b);
log_trace(tt::LogOp, "\toutput_buffer->address(): {}", output_buffer->address());
log_trace(tt::LogOp, "\tsender_eth_buffer_addrs: {}", sender_eth_buffer_addrs.at(b));
Expand Down Expand Up @@ -936,16 +959,6 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
emit_sharded_tensor_kernel_ct_args(device, output_tensor, worker_writer_receiver_ct_args, output_pages_per_shard_y, output_pages_per_shard_x);
}

if (fuse_op) {
uint32_t global_num_workers_per_direction = global_num_workers / num_full_send_directions;
fused_op_signaler->emit_all_gather_fused_op_ct_args(worker_writer_receiver_ct_args, global_num_workers_per_direction, b);
} else {
// Push dummy args so that kernel doesn't error out at compile time from the lack of args when fuse_op=false
for (uint32_t w = 0; w < experimental::ccl::AllGatherFusedOpSignaler::get_num_ct_args(); ++w) {
worker_writer_receiver_ct_args.push_back(static_cast<uint32_t>(0));
}
}

log_trace(tt::LogOp, "Worker {} RW ct args", b);
log_trace(tt::LogOp, "\tall_gather_config.is_output_dram(): {}", all_gather_config.is_output_dram());
log_trace(tt::LogOp, "\treceiver_num_transfers: {}", receiver_worker_num_transfers.at(i).at(b));
Expand Down Expand Up @@ -994,7 +1007,12 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(

/* All Gather fusion */
if (fuse_op) {
fused_op_signaler->emit_all_gather_fused_op_rt_args(worker_writer_receiver_rt_args, is_clockwise_direction ? 0 : 1);
fused_op_signaler->push_all_gather_fused_op_rt_args(
worker_writer_receiver_rt_args,
global_num_workers_per_direction,
b,
is_clockwise_direction ? 0 : 1
);
}

log_trace(tt::LogOp, "Worker {} RW rt args", b);
Expand Down
179 changes: 179 additions & 0 deletions ttnn/cpp/ttnn/operations/ccl/ccl_op_fusion.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#include "tt_metal/host_api.hpp"
#include "tt_metal/impl/program/program.hpp"
#include "ttnn/operations/ccl/ccl_op_fusion.hpp"

namespace ttnn {
namespace experimental {
namespace ccl {

// Capture the fused-op receiver cores (already in NOC coordinates) and the
// per-direction signal semaphores they listen on; marks the fused-op side
// of the signaler as ready.
void AllGatherFusedOpSignaler::init_fused_op(
    const std::vector<CoreCoord>& fused_op_receiver_cores_noc,
    const std::vector<uint32_t>& fused_op_receiver_signal_semaphores
) {
    this->fused_op_receiver_signal_semaphores = fused_op_receiver_signal_semaphores;
    this->fused_op_receiver_cores_noc = fused_op_receiver_cores_noc;
    this->num_fused_op_cores_to_signal = this->fused_op_receiver_cores_noc.size();

    this->initialized_fused_op = true;
}

// Set up the all-gather side of the signaler: allocate the semaphore the
// all-gather workers use to synchronize with each other, and translate their
// logical core coordinates into NOC coordinates.
void AllGatherFusedOpSignaler::init_all_gather(
    Program& program,
    Device const* device,

    CoreRangeSet const& all_gather_workers,
    std::vector<CoreCoord>& all_gather_worker_cores
) {
    // Sync semaphore shared by all the all-gather workers
    this->all_gather_worker_sync_semaphore = CreateSemaphore(program, all_gather_workers, 0);

    // Logical -> NOC coordinate translation for each worker core
    auto& noc_cores = this->all_gather_worker_cores_noc;
    noc_cores.clear();
    noc_cores.reserve(all_gather_worker_cores.size());
    for (auto const& logical_core : all_gather_worker_cores) {
        noc_cores.push_back(device->worker_core_from_logical_core(logical_core));
    }

    this->initialized_all_gather = true;
}

// Serialize the all-gather half of the fused-op handshake into a worker
// kernel's runtime args. The push order is a kernel-side ABI and must not
// change:
//   num_workers_to_sync, curr_worker_index, sync_semaphore,
//   [worker noc x, y]*, num_fused_op_cores, [receiver noc x, y]*,
//   signal_semaphore(direction), wait_for_start, send_start,
//   [start core x, y, start sem]   (present only when sending the start signal)
void AllGatherFusedOpSignaler::push_all_gather_fused_op_rt_args(
    std::vector<uint32_t>& out_rt_args,

    uint32_t num_workers_to_sync,
    uint32_t curr_worker_index,
    uint32_t all_gather_direction,
    std::optional<CoreSemPair> start_signal_core_sem_pair
) {
    TT_ASSERT(initialized_fused_op && initialized_all_gather, "AllGatherFusedOpSignaler not initialized fully.");

    // The start signal only exists in the counter-clockwise direction (1):
    // workers given a core/sem pair send it, the rest wait for it.
    const bool is_counter_clockwise = all_gather_direction == 1;
    const bool send_start_signal = is_counter_clockwise && start_signal_core_sem_pair.has_value();
    const bool wait_for_start_signal = is_counter_clockwise && !start_signal_core_sem_pair.has_value();

    out_rt_args.push_back(num_workers_to_sync);
    out_rt_args.push_back(curr_worker_index);
    out_rt_args.push_back(static_cast<uint32_t>(this->all_gather_worker_sync_semaphore));

    // NOC coords of every all-gather worker participating in the sync
    for (auto const& worker : this->all_gather_worker_cores_noc) {
        out_rt_args.push_back(static_cast<uint32_t>(worker.x));
        out_rt_args.push_back(static_cast<uint32_t>(worker.y));
    }

    // Fused-op (receiver) cores to signal once a tensor slice is available
    out_rt_args.push_back(static_cast<uint32_t>(this->num_fused_op_cores_to_signal));
    for (auto const& receiver : this->fused_op_receiver_cores_noc) {
        out_rt_args.push_back(static_cast<uint32_t>(receiver.x));
        out_rt_args.push_back(static_cast<uint32_t>(receiver.y));
    }

    // Signal semaphore for this direction. Direction 0: clockwise, 1: counter-clockwise
    out_rt_args.push_back(static_cast<uint32_t>(this->fused_op_receiver_signal_semaphores[all_gather_direction]));

    out_rt_args.push_back(static_cast<uint32_t>(wait_for_start_signal));
    out_rt_args.push_back(static_cast<uint32_t>(send_start_signal));

    if (send_start_signal) {
        out_rt_args.push_back(static_cast<uint32_t>(start_signal_core_sem_pair->core.x));
        out_rt_args.push_back(static_cast<uint32_t>(start_signal_core_sem_pair->core.y));
        out_rt_args.push_back(static_cast<uint32_t>(start_signal_core_sem_pair->sem_id));
    }
}


// Used to propagate semaphore/topology information from all_gather to matmul
// in the fused all_gather_matmul op. Pure state capture: records ring layout
// and page offsets so they can later be pushed as matmul runtime args.
void MatmulFusedOpSignaler::init_all_gather(
    uint32_t num_transfers,
    uint32_t ring_size,
    uint32_t start_ring_index,
    uint32_t tensor_slice_shape_width,
    uint32_t output_page_offset,
    bool is_clockwise_direction,

    uint32_t weight_output_page_offset
) {
    // Ring topology of the all-gather this matmul is fused with
    this->num_transfers = num_transfers;
    this->ring_size = ring_size;
    this->start_ring_index = start_ring_index;
    this->is_clockwise_dir = is_clockwise_direction;

    // Per-slice geometry and page offsets (activation and weight variants)
    this->tensor_slice_shape_width = tensor_slice_shape_width;
    this->output_page_offset = output_page_offset;
    this->weight_output_page_offset = weight_output_page_offset;

    initialized_all_gather = true;
}

// Set up the matmul side of the fused all_gather+matmul op: record the NOC
// coordinates of every core that should be signaled when a tensor slice is
// ready, and create one signal semaphore per all-gather direction
// (index 0: clockwise, index 1: counter-clockwise).
void MatmulFusedOpSignaler::init_fused_op(
    Program& program,
    Device const* device,
    const std::variant<CoreRange, CoreRangeSet>& core_range_to_signal
) {
    // Reset any state from a previous initialization. The semaphore list is
    // cleared as well so repeated calls do not accumulate stale semaphore
    // addresses alongside the freshly collected cores.
    this->fused_op_receiver_cores_noc.clear();
    this->fused_op_receiver_signal_semaphores.clear();

    // Expand one logical core range into NOC coords, appending to the receiver list
    auto append_range = [&](const CoreCoord& start_coord, const CoreCoord& end_coord) {
        const auto& cores = grid_to_cores(start_coord, end_coord, true);
        for (const auto& core : cores) {
            this->fused_op_receiver_cores_noc.push_back(device->worker_core_from_logical_core(core));
        }
    };

    // Handle both the single-range and the range-set form of the target grid
    std::visit([&](const auto& arg) {
        using T = std::decay_t<decltype(arg)>;
        if constexpr (std::is_same_v<T, CoreRange>) {
            append_range(arg.start_coord, arg.end_coord);
        } else if constexpr (std::is_same_v<T, CoreRangeSet>) {
            for (const auto& range : arg.ranges()) {
                append_range(range.start_coord, range.end_coord);
            }
        }
    }, core_range_to_signal);

    // One semaphore per all-gather direction (clockwise / counter-clockwise)
    this->fused_op_receiver_signal_semaphores.push_back(CreateSemaphore(program, core_range_to_signal, 0));
    this->fused_op_receiver_signal_semaphores.push_back(CreateSemaphore(program, core_range_to_signal, 0));

    // Number of receiver cores the all-gather workers must signal
    this->num_fused_op_cores_to_signal = this->fused_op_receiver_cores_noc.size();

    initialized_fused_op = true;
}

// Append the matmul-side fused-op runtime args. The push order is a
// kernel-side ABI: num_transfers, ring_size, start_ring_index,
// tensor_slice_shape_width, page_offset, last_output_page_offset,
// is_clockwise, semaphore[0], semaphore[1].
void MatmulFusedOpSignaler::push_matmul_fused_op_rt_args(
    std::vector<uint32_t>& out_rt_args,
    bool use_in1_offset
) {
    TT_ASSERT(initialized_all_gather && initialized_fused_op, "MatmulFusedOpSignaler not initialized fully.");

    // Select which per-device page offset the kernel should advance by:
    // the weight (in1) offset or the activation output offset.
    const uint32_t page_offset = use_in1_offset ? this->weight_output_page_offset : this->output_page_offset;

    out_rt_args.push_back(this->num_transfers);
    out_rt_args.push_back(this->ring_size);
    out_rt_args.push_back(this->start_ring_index);
    out_rt_args.push_back(this->tensor_slice_shape_width);
    out_rt_args.push_back(page_offset);
    out_rt_args.push_back((this->ring_size - 1) * page_offset);
    out_rt_args.push_back(static_cast<uint32_t>(this->is_clockwise_dir));
    out_rt_args.push_back(this->fused_op_receiver_signal_semaphores[0]);
    out_rt_args.push_back(this->fused_op_receiver_signal_semaphores[1]);
}



} // namespace ccl
} // namespace experimental
} // namespace ttnn
Loading

0 comments on commit 5a4fc17

Please sign in to comment.