Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#12652: switch all-gather to worker initiated edm termination mode #14078

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ void kernel_main() {
volatile uint32_t *receiver_read_sem_addr = reinterpret_cast<volatile uint32_t *>(get_semaphore(get_compile_time_arg_val(1)));
constexpr uint32_t half_cb_n_pages = get_compile_time_arg_val(2);
constexpr uint32_t num_buffers_per_channel = get_compile_time_arg_val(3);
constexpr ttnn::ccl::EriscDataMoverTerminationMode edm_termination_mode = static_cast<ttnn::ccl::EriscDataMoverTerminationMode>(get_compile_time_arg_val(4));

uint32_t arg_idx = 0;
const uint32_t eth_receiver_l1_base_addr = get_arg_val<uint32_t>(arg_idx++);
Expand All @@ -28,24 +29,31 @@ void kernel_main() {

constexpr uint32_t cb_id_in0 = tt::CB::c_in0;

ccl::edm::WorkerToEdmReader<ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED> reader(
ccl::edm::WorkerToEdmReader<edm_termination_mode> reader(
ttnn::ccl::WorkerXY(eth_receiver_noc_x, eth_receiver_noc_y),
eth_receiver_l1_base_addr,
num_buffers_per_channel,
eth_receiver_l1_semaphore_addr,
(num_full_chunks > 0 ? num_pages_per_full_chunk : rem_num_pages) * page_size,
receiver_read_sem_addr);

bool last_message = false;
for (uint32_t i = 0; i < num_transfers; ++i) {
if (num_full_chunks > 0) {
for (uint32_t c = 0; c < num_full_chunks; ++c) {
reader.wait_for_payload_available();
reader.fetch_payload_blocking(cb_id_in0, num_pages_per_full_chunk, page_size, false);
if constexpr (edm_termination_mode == ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED) {
last_message = (i == num_transfers - 1 && c == num_full_chunks - 1 && rem_num_pages == 0);
}
reader.fetch_payload_blocking(cb_id_in0, num_pages_per_full_chunk, page_size, last_message);
}
}
if (rem_num_pages > 0) {
if constexpr (edm_termination_mode == ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED) {
last_message = (i == num_transfers - 1);
}
reader.wait_for_payload_available();
reader.fetch_payload_blocking(cb_id_in0, rem_num_pages, page_size, false);
reader.fetch_payload_blocking(cb_id_in0, rem_num_pages, page_size, last_message);
ASSERT(num_pages_per_full_chunk == 0 || num_pages_per_full_chunk > rem_num_pages);
ASSERT(half_cb_n_pages > rem_num_pages);
push_filler_pages_to_cb(cb_id_in0, half_cb_n_pages - rem_num_pages);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,19 @@ void kernel_main() {
volatile uint32_t *const writer_send_sem_ptr = reinterpret_cast<volatile uint32_t *const >(get_semaphore(get_compile_time_arg_val(4)));
constexpr uint32_t half_cb_n_pages = get_compile_time_arg_val(5);
constexpr uint32_t num_buffers_per_channel = get_compile_time_arg_val(6);
constexpr ttnn::ccl::EriscDataMoverTerminationMode edm_termination_mode = static_cast<ttnn::ccl::EriscDataMoverTerminationMode>(get_compile_time_arg_val(7));

ASSERT(half_cb_n_pages > rem_num_pages);

#ifdef SHARDED_MEM_LAYOUT
constexpr tt::tt_metal::TensorMemoryLayout output_tensor_memory_layout = static_cast<tt::tt_metal::TensorMemoryLayout>(get_compile_time_arg_val(7));
constexpr uint32_t output_tensor_shard_grid_height = get_compile_time_arg_val(8);
constexpr uint32_t output_tensor_shard_grid_width = get_compile_time_arg_val(9);
constexpr uint32_t output_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(10);
constexpr uint32_t output_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(11);
constexpr uint32_t output_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(12);
constexpr uint32_t output_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(13);
constexpr bool output_tensor_shard_grid_transposed = get_compile_time_arg_val(14) != 0;
constexpr tt::tt_metal::TensorMemoryLayout output_tensor_memory_layout = static_cast<tt::tt_metal::TensorMemoryLayout>(get_compile_time_arg_val(8));
constexpr uint32_t output_tensor_shard_grid_height = get_compile_time_arg_val(9);
constexpr uint32_t output_tensor_shard_grid_width = get_compile_time_arg_val(10);
constexpr uint32_t output_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(11);
constexpr uint32_t output_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(12);
constexpr uint32_t output_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(13);
constexpr uint32_t output_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(14);
constexpr bool output_tensor_shard_grid_transposed = get_compile_time_arg_val(15) != 0;
#endif

constexpr uint32_t cb_id_in0 = tt::CB::c_in0;
Expand Down Expand Up @@ -118,7 +119,7 @@ void kernel_main() {
#endif
#endif

ccl::edm::WorkerToEdmSender<ttnn::ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED> sender(
ccl::edm::WorkerToEdmSender<edm_termination_mode> sender(
ttnn::ccl::WorkerXY(eth_sender_noc_x, eth_sender_noc_y),
eth_sender_l1_base_addr,
num_buffers_per_channel,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
worker_defines["INTERLEAVED_MEM_LAYOUT"] = "1";
}

constexpr ttnn::ccl::EriscDataMoverTerminationMode edm_termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does all-gather no longer use EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED mode at all now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct. Both to simplify the host code and to enable migration to fabric, we won't be able to use a message-count style of termination mode anymore.

bool full_send_both_directions =
(topology == ccl::Topology::Linear ||
(topology == ccl::Topology::Ring &&
Expand Down Expand Up @@ -347,7 +348,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
edm_sem_addrs_per_link.at(link),
edm_buffer_addrs_per_link.at(link),
ccl::EriscDataMoverBufferSharingMode::NOT_SHARED,
ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED,
edm_termination_mode,
all_gather_config.get_num_buffers_per_channel(),
input_tensor.device()->id());
counter_clockwise_edm_builders.emplace_back(
Expand All @@ -356,7 +357,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
edm_sem_addrs_per_link.at(link),
edm_buffer_addrs_per_link.at(link),
ccl::EriscDataMoverBufferSharingMode::NOT_SHARED,
ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED,
edm_termination_mode,
all_gather_config.get_num_buffers_per_channel(),
input_tensor.device()->id());
}
Expand Down Expand Up @@ -440,6 +441,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
static_cast<uint32_t>(sender_worker_writer_semaphore_id),
static_cast<uint32_t>(cb_num_pages / 2),
static_cast<uint32_t>(num_edm_buffers_per_channel),
static_cast<uint32_t>(edm_termination_mode)
};

if (is_sharded) {
Expand Down Expand Up @@ -474,7 +476,8 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
static_cast<uint32_t>(input_page_size),
static_cast<uint32_t>(receiver_worker_semaphore_id),
static_cast<uint32_t>(cb_num_pages / 2),
static_cast<uint32_t>(num_edm_buffers_per_channel)
static_cast<uint32_t>(num_edm_buffers_per_channel),
static_cast<uint32_t>(edm_termination_mode)
};

log_trace(tt::LogOp, "Worker RR ct args");
Expand Down Expand Up @@ -661,12 +664,8 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
std::vector<bool> is_channel_shrinkable(all_gather_config.get_num_workers_per_link(), false);
std::vector<uint32_t> largest_packets_per_channel(all_gather_config.get_num_workers_per_link(), 0);

std::vector<uint32_t> clockwise_link_buffer_num_messages_to_send;
std::vector<uint32_t> counter_clockwise_link_buffer_num_messages_to_send;
std::vector<uint32_t> edm_semaphores_base_address;
std::vector<uint32_t> link_buffer_sender_addresses;
clockwise_link_buffer_num_messages_to_send.reserve(all_gather_config.get_num_workers_per_link());
counter_clockwise_link_buffer_num_messages_to_send.reserve(all_gather_config.get_num_workers_per_link());
edm_semaphores_base_address.reserve(all_gather_config.get_num_workers_per_link());
link_buffer_sender_addresses.reserve(all_gather_config.get_num_workers_per_link());

Expand Down Expand Up @@ -711,20 +710,9 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
is_channel_shrinkable.at(b) = shrinkable;
largest_packets_per_channel.at(b) = shrinkable ? rem_pages_per_worker.at(b) * input_page_size : all_gather_config.get_eth_buffer_size();
}
for(uint32_t b = 0; b < all_gather_config.get_num_workers_per_link(); ++b) {
// link num messages
clockwise_link_buffer_num_messages_to_send.push_back(
(num_full_chunks_per_worker.at(b) + (rem_pages_per_worker.at(b) > 0 ? 1 : 0)) *
sender_worker_num_transfers.at(i).at(b));
counter_clockwise_link_buffer_num_messages_to_send.push_back(
(num_full_chunks_per_worker.at(b) + (rem_pages_per_worker.at(b) > 0 ? 1 : 0)) *
receiver_worker_num_transfers.at(i).at(b));
}
for(uint32_t b = 0; b < all_gather_config.get_num_workers_per_link(); ++b) {
log_trace(tt::LogOp, "rem_pages_per_worker[{}]: {}", b, rem_pages_per_worker.at(b));
log_trace(tt::LogOp, "num_full_chunks_per_worker[{}]: {}", b, num_full_chunks_per_worker.at(b));
log_trace(tt::LogOp, "clockwise_link_buffer_num_messages_to_send[{}]: {}", b, clockwise_link_buffer_num_messages_to_send.at(b));
log_trace(tt::LogOp, "counter_clockwise_link_buffer_num_messages_to_send[{}]: {}", b, counter_clockwise_link_buffer_num_messages_to_send.at(b));
}

std::vector<uint32_t> receiver_semaphores_base_address;
Expand All @@ -750,7 +738,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
auto &sender_edm_builder = is_buffer_in_clockwise_direction(b) ? clockwise_edm_builders.at(i) : counter_clockwise_edm_builders.at(i);
log_trace(tt::LogOp, "Adding sender EDM channel");
EriscDatamoverBuilder::ChannelBufferInterface const& sender_channel_buffer_info =
sender_edm_builder.add_sender_channel(sender_worker_writer_semaphore_id, clockwise_link_buffer_num_messages_to_send.at(b), sender_worker_coords);
sender_edm_builder.add_sender_channel(sender_worker_writer_semaphore_id, 1, sender_worker_coords);
Copy link
Contributor

@avoraTT avoraTT Oct 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is changing num_eth_messages_to_forward to 1 going to affect perf, or is it just a semantic change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll update the code for readability. This just means the channel will actually forward data - the specific value itself doesn't matter so the API should be cleaned up. Thanks for the question.

if (is_channel_shrinkable.at(b) && largest_packets_per_channel.at(b) > 0) {
TT_ASSERT(largest_packets_per_channel.at(b) > 0);
log_trace(tt::LogOp, "\tsetting channel_max_size to {} for channel {}", largest_packets_per_channel.at(b), b);
Expand All @@ -770,7 +758,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
auto &receiver_edm_builder = is_buffer_in_clockwise_direction(b) ? counter_clockwise_edm_builders.at(i) : clockwise_edm_builders.at(i);
log_trace(tt::LogOp, "Adding receiver EDM channel");
EriscDatamoverBuilder::ChannelBufferInterface const& receiver_channel_buffer_info =
receiver_edm_builder.add_receiver_channel(receiver_worker_semaphore_id, counter_clockwise_link_buffer_num_messages_to_send.at(b), receiver_worker_coords);
receiver_edm_builder.add_receiver_channel(receiver_worker_semaphore_id, 1, receiver_worker_coords);
if (is_channel_shrinkable.at(b) && largest_packets_per_channel.at(b) > 0) {
TT_ASSERT(largest_packets_per_channel.at(b) > 0);
log_trace(tt::LogOp, "\tsetting channel_max_size to {} for channel {}", largest_packets_per_channel.at(b), b);
Expand Down Expand Up @@ -805,10 +793,10 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
std::vector<uint32_t> args = {
static_cast<uint32_t>(input_buffer->address()),
static_cast<uint32_t>(output_buffer->address()),
static_cast<uint32_t>(sender_worker_num_transfers.at(i).at(b)), // move to rt
static_cast<uint32_t>(num_full_chunks_per_worker.at(b)), // move to rt
static_cast<uint32_t>(pages_per_eth_l1_buffer.at(b)), // move to rt
static_cast<uint32_t>(rem_pages_per_worker.at(b)), // move to rt
static_cast<uint32_t>(sender_worker_num_transfers.at(i).at(b)),
static_cast<uint32_t>(num_full_chunks_per_worker.at(b)),
static_cast<uint32_t>(pages_per_eth_l1_buffer.at(b)),
static_cast<uint32_t>(rem_pages_per_worker.at(b)),
static_cast<uint32_t>(tensor_slicer.input_start_page_idx),
static_cast<uint32_t>(tensor_slicer.output_start_page_idx),
static_cast<uint32_t>(tensor_slicer.output_start_addr_offset),
Expand Down
Loading