Skip to content

Commit

Permalink
#12652: switch all-gather to worker initiated edm termination mode
Browse files: browse the repository at this point in the history
  • Loading branch information
SeanNijjar committed Oct 22, 2024
1 parent 0326577 commit eefea4a
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 36 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -13,6 +13,7 @@ void kernel_main() {
volatile uint32_t *receiver_read_sem_addr = reinterpret_cast<volatile uint32_t *>(get_semaphore(get_compile_time_arg_val(1)));
constexpr uint32_t half_cb_n_pages = get_compile_time_arg_val(2);
constexpr uint32_t num_buffers_per_channel = get_compile_time_arg_val(3);
constexpr ttnn::ccl::EriscDataMoverTerminationMode edm_termination_mode = static_cast<ttnn::ccl::EriscDataMoverTerminationMode>(get_compile_time_arg_val(4));

uint32_t arg_idx = 0;
const uint32_t eth_receiver_l1_base_addr = get_arg_val<uint32_t>(arg_idx++);
Expand All @@ -28,24 +29,31 @@ void kernel_main() {

constexpr uint32_t cb_id_in0 = tt::CB::c_in0;

ccl::edm::WorkerToEdmReader<ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED> reader(
ccl::edm::WorkerToEdmReader<edm_termination_mode> reader(
ttnn::ccl::WorkerXY(eth_receiver_noc_x, eth_receiver_noc_y),
eth_receiver_l1_base_addr,
num_buffers_per_channel,
eth_receiver_l1_semaphore_addr,
(num_full_chunks > 0 ? num_pages_per_full_chunk : rem_num_pages) * page_size,
receiver_read_sem_addr);

bool last_message = false;
for (uint32_t i = 0; i < num_transfers; ++i) {
if (num_full_chunks > 0) {
for (uint32_t c = 0; c < num_full_chunks; ++c) {
reader.wait_for_payload_available();
reader.fetch_payload_blocking(cb_id_in0, num_pages_per_full_chunk, page_size, false);
if constexpr (edm_termination_mode == ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED) {
last_message = (i == num_transfers - 1 && c == num_full_chunks - 1 && rem_num_pages == 0);
}
reader.fetch_payload_blocking(cb_id_in0, num_pages_per_full_chunk, page_size, last_message);
}
}
if (rem_num_pages > 0) {
if constexpr (edm_termination_mode == ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED) {
last_message = (i == num_transfers - 1);
}
reader.wait_for_payload_available();
reader.fetch_payload_blocking(cb_id_in0, rem_num_pages, page_size, false);
reader.fetch_payload_blocking(cb_id_in0, rem_num_pages, page_size, last_message);
ASSERT(num_pages_per_full_chunk == 0 || num_pages_per_full_chunk > rem_num_pages);
ASSERT(half_cb_n_pages > rem_num_pages);
push_filler_pages_to_cb(cb_id_in0, half_cb_n_pages - rem_num_pages);
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -55,18 +55,19 @@ void kernel_main() {
volatile uint32_t *const writer_send_sem_ptr = reinterpret_cast<volatile uint32_t *const >(get_semaphore(get_compile_time_arg_val(4)));
constexpr uint32_t half_cb_n_pages = get_compile_time_arg_val(5);
constexpr uint32_t num_buffers_per_channel = get_compile_time_arg_val(6);
constexpr ttnn::ccl::EriscDataMoverTerminationMode edm_termination_mode = static_cast<ttnn::ccl::EriscDataMoverTerminationMode>(get_compile_time_arg_val(7));

ASSERT(half_cb_n_pages > rem_num_pages);

#ifdef SHARDED_MEM_LAYOUT
constexpr tt::tt_metal::TensorMemoryLayout output_tensor_memory_layout = static_cast<tt::tt_metal::TensorMemoryLayout>(get_compile_time_arg_val(7));
constexpr uint32_t output_tensor_shard_grid_height = get_compile_time_arg_val(8);
constexpr uint32_t output_tensor_shard_grid_width = get_compile_time_arg_val(9);
constexpr uint32_t output_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(10);
constexpr uint32_t output_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(11);
constexpr uint32_t output_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(12);
constexpr uint32_t output_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(13);
constexpr bool output_tensor_shard_grid_transposed = get_compile_time_arg_val(14) != 0;
constexpr tt::tt_metal::TensorMemoryLayout output_tensor_memory_layout = static_cast<tt::tt_metal::TensorMemoryLayout>(get_compile_time_arg_val(8));
constexpr uint32_t output_tensor_shard_grid_height = get_compile_time_arg_val(9);
constexpr uint32_t output_tensor_shard_grid_width = get_compile_time_arg_val(10);
constexpr uint32_t output_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(11);
constexpr uint32_t output_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(12);
constexpr uint32_t output_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(13);
constexpr uint32_t output_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(14);
constexpr bool output_tensor_shard_grid_transposed = get_compile_time_arg_val(15) != 0;
#endif

constexpr uint32_t cb_id_in0 = tt::CB::c_in0;
Expand Down Expand Up @@ -118,7 +119,7 @@ void kernel_main() {
#endif
#endif

ccl::edm::WorkerToEdmSender<ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED> sender(
ccl::edm::WorkerToEdmSender<edm_termination_mode> sender(
ttnn::ccl::WorkerXY(eth_sender_noc_x, eth_sender_noc_y),
eth_sender_l1_base_addr,
num_buffers_per_channel,
Expand Down
Original file line number | Diff line number | Diff line change
Expand Up @@ -302,6 +302,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
worker_defines["INTERLEAVED_MEM_LAYOUT"] = "1";
}

constexpr ttnn::ccl::EriscDataMoverTerminationMode edm_termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED;
bool full_send_both_directions =
(topology == ccl::Topology::Linear ||
(topology == ccl::Topology::Ring &&
Expand Down Expand Up @@ -347,7 +348,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
edm_sem_addrs_per_link.at(link),
edm_buffer_addrs_per_link.at(link),
ccl::EriscDataMoverBufferSharingMode::NOT_SHARED,
ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED,
edm_termination_mode,
all_gather_config.get_num_buffers_per_channel(),
input_tensor.device()->id());
counter_clockwise_edm_builders.emplace_back(
Expand All @@ -356,7 +357,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
edm_sem_addrs_per_link.at(link),
edm_buffer_addrs_per_link.at(link),
ccl::EriscDataMoverBufferSharingMode::NOT_SHARED,
ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED,
edm_termination_mode,
all_gather_config.get_num_buffers_per_channel(),
input_tensor.device()->id());
}
Expand Down Expand Up @@ -440,6 +441,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
static_cast<uint32_t>(sender_worker_writer_semaphore_id),
static_cast<uint32_t>(cb_num_pages / 2),
static_cast<uint32_t>(num_edm_buffers_per_channel),
static_cast<uint32_t>(edm_termination_mode)
};

if (is_sharded) {
Expand Down Expand Up @@ -474,7 +476,8 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
static_cast<uint32_t>(input_page_size),
static_cast<uint32_t>(receiver_worker_semaphore_id),
static_cast<uint32_t>(cb_num_pages / 2),
static_cast<uint32_t>(num_edm_buffers_per_channel)
static_cast<uint32_t>(num_edm_buffers_per_channel),
static_cast<uint32_t>(edm_termination_mode)
};

log_trace(tt::LogOp, "Worker RR ct args");
Expand Down Expand Up @@ -661,12 +664,8 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
std::vector<bool> is_channel_shrinkable(all_gather_config.get_num_workers_per_link(), false);
std::vector<uint32_t> largest_packets_per_channel(all_gather_config.get_num_workers_per_link(), 0);

std::vector<uint32_t> clockwise_link_buffer_num_messages_to_send;
std::vector<uint32_t> counter_clockwise_link_buffer_num_messages_to_send;
std::vector<uint32_t> edm_semaphores_base_address;
std::vector<uint32_t> link_buffer_sender_addresses;
clockwise_link_buffer_num_messages_to_send.reserve(all_gather_config.get_num_workers_per_link());
counter_clockwise_link_buffer_num_messages_to_send.reserve(all_gather_config.get_num_workers_per_link());
edm_semaphores_base_address.reserve(all_gather_config.get_num_workers_per_link());
link_buffer_sender_addresses.reserve(all_gather_config.get_num_workers_per_link());

Expand Down Expand Up @@ -711,20 +710,9 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
is_channel_shrinkable.at(b) = shrinkable;
largest_packets_per_channel.at(b) = shrinkable ? rem_pages_per_worker.at(b) * input_page_size : all_gather_config.get_eth_buffer_size();
}
for(uint32_t b = 0; b < all_gather_config.get_num_workers_per_link(); ++b) {
// link num messages
clockwise_link_buffer_num_messages_to_send.push_back(
(num_full_chunks_per_worker.at(b) + (rem_pages_per_worker.at(b) > 0 ? 1 : 0)) *
sender_worker_num_transfers.at(i).at(b));
counter_clockwise_link_buffer_num_messages_to_send.push_back(
(num_full_chunks_per_worker.at(b) + (rem_pages_per_worker.at(b) > 0 ? 1 : 0)) *
receiver_worker_num_transfers.at(i).at(b));
}
for(uint32_t b = 0; b < all_gather_config.get_num_workers_per_link(); ++b) {
log_trace(tt::LogOp, "rem_pages_per_worker[{}]: {}", b, rem_pages_per_worker.at(b));
log_trace(tt::LogOp, "num_full_chunks_per_worker[{}]: {}", b, num_full_chunks_per_worker.at(b));
log_trace(tt::LogOp, "clockwise_link_buffer_num_messages_to_send[{}]: {}", b, clockwise_link_buffer_num_messages_to_send.at(b));
log_trace(tt::LogOp, "counter_clockwise_link_buffer_num_messages_to_send[{}]: {}", b, counter_clockwise_link_buffer_num_messages_to_send.at(b));
}

std::vector<uint32_t> receiver_semaphores_base_address;
Expand All @@ -750,7 +738,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
auto &sender_edm_builder = is_buffer_in_clockwise_direction(b) ? clockwise_edm_builders.at(i) : counter_clockwise_edm_builders.at(i);
log_trace(tt::LogOp, "Adding sender EDM channel");
EriscDatamoverBuilder::ChannelBufferInterface const& sender_channel_buffer_info =
sender_edm_builder.add_sender_channel(sender_worker_writer_semaphore_id, clockwise_link_buffer_num_messages_to_send.at(b), sender_worker_coords);
sender_edm_builder.add_sender_channel(sender_worker_writer_semaphore_id, 1, sender_worker_coords);
if (is_channel_shrinkable.at(b) && largest_packets_per_channel.at(b) > 0) {
TT_ASSERT(largest_packets_per_channel.at(b) > 0);
log_trace(tt::LogOp, "\tsetting channel_max_size to {} for channel {}", largest_packets_per_channel.at(b), b);
Expand All @@ -770,7 +758,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
auto &receiver_edm_builder = is_buffer_in_clockwise_direction(b) ? counter_clockwise_edm_builders.at(i) : clockwise_edm_builders.at(i);
log_trace(tt::LogOp, "Adding receiver EDM channel");
EriscDatamoverBuilder::ChannelBufferInterface const& receiver_channel_buffer_info =
receiver_edm_builder.add_receiver_channel(receiver_worker_semaphore_id, counter_clockwise_link_buffer_num_messages_to_send.at(b), receiver_worker_coords);
receiver_edm_builder.add_receiver_channel(receiver_worker_semaphore_id, 1, receiver_worker_coords);
if (is_channel_shrinkable.at(b) && largest_packets_per_channel.at(b) > 0) {
TT_ASSERT(largest_packets_per_channel.at(b) > 0);
log_trace(tt::LogOp, "\tsetting channel_max_size to {} for channel {}", largest_packets_per_channel.at(b), b);
Expand Down Expand Up @@ -805,10 +793,10 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
std::vector<uint32_t> args = {
static_cast<uint32_t>(input_buffer->address()),
static_cast<uint32_t>(output_buffer->address()),
static_cast<uint32_t>(sender_worker_num_transfers.at(i).at(b)), // move to rt
static_cast<uint32_t>(num_full_chunks_per_worker.at(b)), // move to rt
static_cast<uint32_t>(pages_per_eth_l1_buffer.at(b)), // move to rt
static_cast<uint32_t>(rem_pages_per_worker.at(b)), // move to rt
static_cast<uint32_t>(sender_worker_num_transfers.at(i).at(b)),
static_cast<uint32_t>(num_full_chunks_per_worker.at(b)),
static_cast<uint32_t>(pages_per_eth_l1_buffer.at(b)),
static_cast<uint32_t>(rem_pages_per_worker.at(b)),
static_cast<uint32_t>(tensor_slicer.input_start_page_idx),
static_cast<uint32_t>(tensor_slicer.output_start_page_idx),
static_cast<uint32_t>(tensor_slicer.output_start_addr_offset),
Expand Down

0 comments on commit eefea4a

Please sign in to comment.