#12652: switch all-gather to worker initiated edm termination mode #14078
base: main
Changes from 1 commit
@@ -302,6 +302,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
         worker_defines["INTERLEAVED_MEM_LAYOUT"] = "1";
     }

+    constexpr ttnn::ccl::EriscDataMoverTerminationMode edm_termination_mode = ttnn::ccl::EriscDataMoverTerminationMode::WORKER_INITIATED;
     bool full_send_both_directions =
         (topology == ccl::Topology::Linear ||
          (topology == ccl::Topology::Ring &&
@@ -347,7 +348,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
             edm_sem_addrs_per_link.at(link),
             edm_buffer_addrs_per_link.at(link),
             ccl::EriscDataMoverBufferSharingMode::NOT_SHARED,
-            ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED,
+            edm_termination_mode,
             all_gather_config.get_num_buffers_per_channel(),
             input_tensor.device()->id());
         counter_clockwise_edm_builders.emplace_back(
@@ -356,7 +357,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
             edm_sem_addrs_per_link.at(link),
             edm_buffer_addrs_per_link.at(link),
             ccl::EriscDataMoverBufferSharingMode::NOT_SHARED,
-            ccl::EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED,
+            edm_termination_mode,
             all_gather_config.get_num_buffers_per_channel(),
             input_tensor.device()->id());
     }
@@ -440,6 +441,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
         static_cast<uint32_t>(sender_worker_writer_semaphore_id),
         static_cast<uint32_t>(cb_num_pages / 2),
         static_cast<uint32_t>(num_edm_buffers_per_channel),
+        static_cast<uint32_t>(edm_termination_mode)
     };

     if (is_sharded) {
@@ -474,7 +476,8 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
         static_cast<uint32_t>(input_page_size),
         static_cast<uint32_t>(receiver_worker_semaphore_id),
         static_cast<uint32_t>(cb_num_pages / 2),
-        static_cast<uint32_t>(num_edm_buffers_per_channel)
+        static_cast<uint32_t>(num_edm_buffers_per_channel),
+        static_cast<uint32_t>(edm_termination_mode)
     };

     log_trace(tt::LogOp, "Worker RR ct args");
@@ -661,12 +664,8 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
     std::vector<bool> is_channel_shrinkable(all_gather_config.get_num_workers_per_link(), false);
     std::vector<uint32_t> largest_packets_per_channel(all_gather_config.get_num_workers_per_link(), 0);

-    std::vector<uint32_t> clockwise_link_buffer_num_messages_to_send;
-    std::vector<uint32_t> counter_clockwise_link_buffer_num_messages_to_send;
     std::vector<uint32_t> edm_semaphores_base_address;
     std::vector<uint32_t> link_buffer_sender_addresses;
-    clockwise_link_buffer_num_messages_to_send.reserve(all_gather_config.get_num_workers_per_link());
-    counter_clockwise_link_buffer_num_messages_to_send.reserve(all_gather_config.get_num_workers_per_link());
     edm_semaphores_base_address.reserve(all_gather_config.get_num_workers_per_link());
     link_buffer_sender_addresses.reserve(all_gather_config.get_num_workers_per_link());

@@ -711,20 +710,9 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper( | |
is_channel_shrinkable.at(b) = shrinkable; | ||
largest_packets_per_channel.at(b) = shrinkable ? rem_pages_per_worker.at(b) * input_page_size : all_gather_config.get_eth_buffer_size(); | ||
} | ||
for(uint32_t b = 0; b < all_gather_config.get_num_workers_per_link(); ++b) { | ||
// link num messages | ||
clockwise_link_buffer_num_messages_to_send.push_back( | ||
(num_full_chunks_per_worker.at(b) + (rem_pages_per_worker.at(b) > 0 ? 1 : 0)) * | ||
sender_worker_num_transfers.at(i).at(b)); | ||
counter_clockwise_link_buffer_num_messages_to_send.push_back( | ||
(num_full_chunks_per_worker.at(b) + (rem_pages_per_worker.at(b) > 0 ? 1 : 0)) * | ||
receiver_worker_num_transfers.at(i).at(b)); | ||
} | ||
for(uint32_t b = 0; b < all_gather_config.get_num_workers_per_link(); ++b) { | ||
log_trace(tt::LogOp, "rem_pages_per_worker[{}]: {}", b, rem_pages_per_worker.at(b)); | ||
log_trace(tt::LogOp, "num_full_chunks_per_worker[{}]: {}", b, num_full_chunks_per_worker.at(b)); | ||
log_trace(tt::LogOp, "clockwise_link_buffer_num_messages_to_send[{}]: {}", b, clockwise_link_buffer_num_messages_to_send.at(b)); | ||
log_trace(tt::LogOp, "counter_clockwise_link_buffer_num_messages_to_send[{}]: {}", b, counter_clockwise_link_buffer_num_messages_to_send.at(b)); | ||
} | ||
|
||
std::vector<uint32_t> receiver_semaphores_base_address; | ||
|
@@ -750,7 +738,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
         auto &sender_edm_builder = is_buffer_in_clockwise_direction(b) ? clockwise_edm_builders.at(i) : counter_clockwise_edm_builders.at(i);
         log_trace(tt::LogOp, "Adding sender EDM channel");
         EriscDatamoverBuilder::ChannelBufferInterface const& sender_channel_buffer_info =
-            sender_edm_builder.add_sender_channel(sender_worker_writer_semaphore_id, clockwise_link_buffer_num_messages_to_send.at(b), sender_worker_coords);
+            sender_edm_builder.add_sender_channel(sender_worker_writer_semaphore_id, 1, sender_worker_coords);
         if (is_channel_shrinkable.at(b) && largest_packets_per_channel.at(b) > 0) {
             TT_ASSERT(largest_packets_per_channel.at(b) > 0);
             log_trace(tt::LogOp, "\tsetting channel_max_size to {} for channel {}", largest_packets_per_channel.at(b), b);

Reviewer: is changing

Author: I'll update the code for readability. This just means the channel will actually forward data; the specific value itself doesn't matter, so the API should be cleaned up. Thanks for the question.
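To make the author's point concrete, here is a minimal standalone sketch of why the hardcoded `1` is behaviorally safe. The `TerminationMode` enum, `BuilderSketch` class, and its method below are hypothetical stand-ins, not the actual tt-metal `EriscDatamoverBuilder` API: the idea is that in worker-initiated mode the count is never compared against forwarded traffic, so any nonzero placeholder behaves the same.

```cpp
// Hypothetical sketch, not the real tt-metal builder. In WORKER_INITIATED
// mode the per-channel message count no longer gates EDM shutdown, so only
// "zero vs nonzero" (does the channel forward data at all?) could matter.
#include <cstdint>
#include <stdexcept>

enum class TerminationMode : uint32_t { MESSAGE_COUNT_REACHED, WORKER_INITIATED };

struct ChannelSketch {
    uint32_t semaphore_id;
    uint32_t num_messages;  // meaningful only in MESSAGE_COUNT_REACHED mode
};

class BuilderSketch {
public:
    explicit BuilderSketch(TerminationMode mode) : mode_(mode) {}

    ChannelSketch add_sender_channel(uint32_t semaphore_id, uint32_t num_messages) {
        // Legacy mode needs an exact count; worker-initiated mode treats the
        // value as a don't-care placeholder (hence the `1` in the diff above).
        if (mode_ == TerminationMode::MESSAGE_COUNT_REACHED && num_messages == 0) {
            throw std::invalid_argument("legacy mode requires a real message count");
        }
        return ChannelSketch{semaphore_id, num_messages};
    }

private:
    TerminationMode mode_;
};
```

A cleaned-up signature would presumably drop the count parameter entirely in worker-initiated mode, which reads as the API cleanup the author has in mind.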
@@ -770,7 +758,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
         auto &receiver_edm_builder = is_buffer_in_clockwise_direction(b) ? counter_clockwise_edm_builders.at(i) : clockwise_edm_builders.at(i);
         log_trace(tt::LogOp, "Adding receiver EDM channel");
         EriscDatamoverBuilder::ChannelBufferInterface const& receiver_channel_buffer_info =
-            receiver_edm_builder.add_receiver_channel(receiver_worker_semaphore_id, counter_clockwise_link_buffer_num_messages_to_send.at(b), receiver_worker_coords);
+            receiver_edm_builder.add_receiver_channel(receiver_worker_semaphore_id, 1, receiver_worker_coords);
         if (is_channel_shrinkable.at(b) && largest_packets_per_channel.at(b) > 0) {
             TT_ASSERT(largest_packets_per_channel.at(b) > 0);
             log_trace(tt::LogOp, "\tsetting channel_max_size to {} for channel {}", largest_packets_per_channel.at(b), b);
@@ -805,10 +793,10 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
         std::vector<uint32_t> args = {
             static_cast<uint32_t>(input_buffer->address()),
             static_cast<uint32_t>(output_buffer->address()),
-            static_cast<uint32_t>(sender_worker_num_transfers.at(i).at(b)), // move to rt
-            static_cast<uint32_t>(num_full_chunks_per_worker.at(b)), // move to rt
-            static_cast<uint32_t>(pages_per_eth_l1_buffer.at(b)), // move to rt
-            static_cast<uint32_t>(rem_pages_per_worker.at(b)), // move to rt
+            static_cast<uint32_t>(sender_worker_num_transfers.at(i).at(b)),
+            static_cast<uint32_t>(num_full_chunks_per_worker.at(b)),
+            static_cast<uint32_t>(pages_per_eth_l1_buffer.at(b)),
+            static_cast<uint32_t>(rem_pages_per_worker.at(b)),
             static_cast<uint32_t>(tensor_slicer.input_start_page_idx),
             static_cast<uint32_t>(tensor_slicer.output_start_page_idx),
             static_cast<uint32_t>(tensor_slicer.output_start_addr_offset),
Reviewer: Does all gather not use EriscDataMoverTerminationMode::MESSAGE_COUNT_REACHED mode at all now?

Author: Correct. Both to simplify host code and to enable migration to fabric, we won't be able to use a message-count type of mode anymore.
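The control-flow difference between the two modes can be sketched in plain host C++. Everything below is an illustrative stand-in for the EDM kernel, assuming a termination flag that workers raise when their transfers complete; it is not actual tt-metal firmware.

```cpp
// Illustrative contrast of the two EDM termination strategies; plain C++
// standing in for the Erisc Data Mover kernel, not real device code.
#include <atomic>
#include <cstdint>

std::atomic<uint32_t> messages_forwarded{0};
std::atomic<bool> worker_termination_signal{false};

// MESSAGE_COUNT_REACHED (legacy): the host precomputes exactly how many
// messages cross each channel (the per-channel bookkeeping this PR deletes),
// and the data mover exits once that count is hit.
void run_message_count_mode(uint32_t expected_messages) {
    while (messages_forwarded.load(std::memory_order_acquire) < expected_messages) {
        // ... forward one message, then bump the counter ...
        messages_forwarded.fetch_add(1, std::memory_order_release);
    }
}

// WORKER_INITIATED: the data mover forwards traffic until workers, who know
// when their own transfers finish, raise the termination signal. The host
// no longer needs to compute message counts per channel.
void run_worker_initiated_mode() {
    while (!worker_termination_signal.load(std::memory_order_acquire)) {
        // ... forward any pending messages ...
    }
}

int main() {
    worker_termination_signal.store(true);  // pretend a worker signaled completion
    run_worker_initiated_mode();
    run_message_count_mode(0);
    return 0;
}
```

This also suggests why message counting is a poor fit for a fabric: counts must be known per channel ahead of time, whereas a worker-raised signal composes across hops without the host predicting traffic.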