Skip to content

Commit

Permalink
#0: Remove dynamic allocation of sub-devices/expected workers pairs, …
Browse files Browse the repository at this point in the history
…and pass them as separate spans
  • Loading branch information
tt-aho committed Nov 12, 2024
1 parent e5cb76d commit 5c58e48
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 86 deletions.
4 changes: 4 additions & 0 deletions tt_metal/impl/device/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3672,6 +3672,10 @@ void Device::remove_sub_device_manager(SubDeviceManagerId sub_device_manager_id)
this->sub_device_managers_.erase(sub_device_manager);
}

const std::vector<SubDeviceId> &Device::get_sub_device_ids() const {
return this->active_sub_device_manager_->get_sub_device_ids();
}

} // namespace tt_metal

} // namespace tt
1 change: 1 addition & 0 deletions tt_metal/impl/device/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,7 @@ class Device {
void load_sub_device_manager(SubDeviceManagerId sub_device_manager_id);
void clear_loaded_sub_device_manager();
void remove_sub_device_manager(SubDeviceManagerId sub_device_manager_id);
const std::vector<SubDeviceId> &get_sub_device_ids() const;
private:
void initialize_default_sub_device_state(size_t l1_small_size, size_t trace_region_size, const std::vector<uint32_t> &l1_bank_remap);
void reset_worker_launch_message_buffer_state(uint32_t num_entries);
Expand Down
144 changes: 74 additions & 70 deletions tt_metal/impl/dispatch/command_queue.cpp

Large diffs are not rendered by default.

37 changes: 25 additions & 12 deletions tt_metal/impl/dispatch/command_queue.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ class EnqueueReadBufferCommand : public Command {
Device* device;
uint32_t command_queue_id;
NOC noc_index;
tt::stl::Span<const std::pair<uint32_t, uint32_t>> expected_num_workers_completed;
tt::stl::Span<const uint32_t> expected_num_workers_completed;
tt::stl::Span<const SubDeviceId> sub_device_ids;
uint32_t src_page_index;
uint32_t pages_to_read;

Expand All @@ -92,7 +93,8 @@ class EnqueueReadBufferCommand : public Command {
Buffer& buffer,
void* dst,
SystemMemoryManager& manager,
tt::stl::Span<const std::pair<uint32_t, uint32_t>> expected_num_workers_completed,
tt::stl::Span<const uint32_t> expected_num_workers_completed,
tt::stl::Span<const SubDeviceId> sub_device_ids,
uint32_t src_page_index = 0,
std::optional<uint32_t> pages_to_read = std::nullopt);

Expand All @@ -115,7 +117,8 @@ class EnqueueReadInterleavedBufferCommand : public EnqueueReadBufferCommand {
Buffer& buffer,
void* dst,
SystemMemoryManager& manager,
tt::stl::Span<const std::pair<uint32_t, uint32_t>> expected_num_workers_completed,
tt::stl::Span<const uint32_t> expected_num_workers_completed,
tt::stl::Span<const SubDeviceId> sub_device_ids,
uint32_t src_page_index = 0,
std::optional<uint32_t> pages_to_read = std::nullopt) :
EnqueueReadBufferCommand(
Expand All @@ -126,6 +129,7 @@ class EnqueueReadInterleavedBufferCommand : public EnqueueReadBufferCommand {
dst,
manager,
expected_num_workers_completed,
sub_device_ids,
src_page_index,
pages_to_read) {}
};
Expand All @@ -144,7 +148,8 @@ class EnqueueReadShardedBufferCommand : public EnqueueReadBufferCommand {
Buffer& buffer,
void* dst,
SystemMemoryManager& manager,
tt::stl::Span<const std::pair<uint32_t, uint32_t>> expected_num_workers_completed,
tt::stl::Span<const uint32_t> expected_num_workers_completed,
tt::stl::Span<const SubDeviceId> sub_device_ids,
const CoreCoord& core,
uint32_t bank_base_address,
uint32_t src_page_index = 0,
Expand All @@ -157,6 +162,7 @@ class EnqueueReadShardedBufferCommand : public EnqueueReadBufferCommand {
dst,
manager,
expected_num_workers_completed,
sub_device_ids,
src_page_index,
pages_to_read),
core(core),
Expand All @@ -179,7 +185,8 @@ class EnqueueWriteBufferCommand : public Command {
NOC noc_index;
const void* src;
const Buffer& buffer;
tt::stl::Span<const std::pair<uint32_t, uint32_t>> expected_num_workers_completed;
tt::stl::Span<const uint32_t> expected_num_workers_completed;
tt::stl::Span<const SubDeviceId> sub_device_ids;
uint32_t bank_base_address;
uint32_t padded_page_size;
uint32_t dst_page_index;
Expand All @@ -195,7 +202,8 @@ class EnqueueWriteBufferCommand : public Command {
const void* src,
SystemMemoryManager& manager,
bool issue_wait,
tt::stl::Span<const std::pair<uint32_t, uint32_t>> expected_num_workers_completed,
tt::stl::Span<const uint32_t> expected_num_workers_completed,
tt::stl::Span<const SubDeviceId> sub_device_ids,
uint32_t bank_base_address,
uint32_t padded_page_size,
uint32_t dst_page_index = 0,
Expand All @@ -222,7 +230,8 @@ class EnqueueWriteInterleavedBufferCommand : public EnqueueWriteBufferCommand {
const void* src,
SystemMemoryManager& manager,
bool issue_wait,
tt::stl::Span<const std::pair<uint32_t, uint32_t>> expected_num_workers_completed,
tt::stl::Span<const uint32_t> expected_num_workers_completed,
tt::stl::Span<const SubDeviceId> sub_device_ids,
uint32_t bank_base_address,
uint32_t padded_page_size,
uint32_t dst_page_index = 0,
Expand All @@ -236,6 +245,7 @@ class EnqueueWriteInterleavedBufferCommand : public EnqueueWriteBufferCommand {
manager,
issue_wait,
expected_num_workers_completed,
sub_device_ids,
bank_base_address,
padded_page_size,
dst_page_index,
Expand All @@ -261,7 +271,8 @@ class EnqueueWriteShardedBufferCommand : public EnqueueWriteBufferCommand {
const void* src,
SystemMemoryManager& manager,
bool issue_wait,
tt::stl::Span<const std::pair<uint32_t, uint32_t>> expected_num_workers_completed,
tt::stl::Span<const uint32_t> expected_num_workers_completed,
tt::stl::Span<const SubDeviceId> sub_device_ids,
uint32_t bank_base_address,
const std::shared_ptr<const BufferPageMapping>& buffer_page_mapping,
const CoreCoord& core,
Expand All @@ -277,6 +288,7 @@ class EnqueueWriteShardedBufferCommand : public EnqueueWriteBufferCommand {
manager,
issue_wait,
expected_num_workers_completed,
sub_device_ids,
bank_base_address,
padded_page_size,
dst_page_index,
Expand Down Expand Up @@ -346,7 +358,8 @@ class EnqueueRecordEventCommand : public Command {
NOC noc_index;
SystemMemoryManager& manager;
uint32_t event_id;
tt::stl::Span<const std::pair<uint32_t, uint32_t>> expected_num_workers_completed;
tt::stl::Span<const uint32_t> expected_num_workers_completed;
tt::stl::Span<const SubDeviceId> sub_device_ids;
bool clear_count;
bool write_barrier;

Expand All @@ -357,7 +370,8 @@ class EnqueueRecordEventCommand : public Command {
NOC noc_index,
SystemMemoryManager& manager,
uint32_t event_id,
tt::stl::Span<const std::pair<uint32_t, uint32_t>> expected_num_workers_completed,
tt::stl::Span<const uint32_t> expected_num_workers_completed,
tt::stl::Span<const SubDeviceId> sub_device_ids,
bool clear_count = false,
bool write_barrier = true);

Expand Down Expand Up @@ -571,9 +585,8 @@ class HWCommandQueue {
void increment_num_entries_in_completion_q();
void set_exit_condition();

WorkerConfigBufferMgr& get_config_buffer_mgr(SubDeviceId sub_device_id);
WorkerConfigBufferMgr& get_config_buffer_mgr(uint32_t index);
void reset_config_buffer_mgr(const uint32_t num_entries);
std::vector<std::pair<uint32_t, uint32_t>> get_expected_workers_completed(tt::stl::Span<const SubDeviceId> sub_device_ids) const;

friend void EnqueueTraceImpl(CommandQueue& cq, uint32_t trace_id, bool blocking);
friend void EnqueueProgramImpl(
Expand Down
8 changes: 4 additions & 4 deletions tt_metal/impl/program/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1514,9 +1514,9 @@ uint32_t detail::Program_::get_sem_base_addr(Device *device, CoreCoord logical_c
// TODO: This restriction can be lifted once we have support for programs spanning multiple sub-devices
// Semaphores across sub-devices are expected to have the same address
TT_FATAL(sub_device_ids.size() == 1, "get_sem_base_addr currently only supports programs spanning a single sub-device");
auto sub_device_id = sub_device_ids[0];
auto sub_device_index = sub_device_ids[0].to_index();
uint32_t base_addr = device->using_fast_dispatch
? this->last_used_command_queue_for_testing->get_config_buffer_mgr(sub_device_id).get_last_slot_addr(
? this->last_used_command_queue_for_testing->get_config_buffer_mgr(sub_device_index).get_last_slot_addr(
programmable_core_type)
: hal.get_dev_addr(programmable_core_type, HalL1MemAddrType::KERNEL_CONFIG);

Expand All @@ -1536,9 +1536,9 @@ uint32_t detail::Program_::get_cb_base_addr(Device *device, CoreCoord logical_co
// TODO: This restriction can be lifted once this function is changed to return a vector of addresses
// Addresses are not the same across sub-devices
TT_FATAL(sub_device_ids.size() == 1, "get_sem_base_addr currently only supports programs spanning a single sub-device");
auto sub_device_id = sub_device_ids[0];
auto sub_device_index = sub_device_ids[0].to_index();
uint32_t base_addr = device->using_fast_dispatch
? this->last_used_command_queue_for_testing->get_config_buffer_mgr(sub_device_id).get_last_slot_addr(
? this->last_used_command_queue_for_testing->get_config_buffer_mgr(sub_device_index).get_last_slot_addr(
programmable_core_type)
: hal.get_dev_addr(programmable_core_type, HalL1MemAddrType::KERNEL_CONFIG);

Expand Down
16 changes: 16 additions & 0 deletions tt_metal/impl/sub_device/sub_device_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
//
// SPDX-License-Identifier: Apache-2.0

#include <vector>

#include "tt_metal/impl/sub_device/sub_device_manager.hpp"

#include "tt_metal/common/assert.hpp"
Expand All @@ -26,6 +28,7 @@ SubDeviceManager::SubDeviceManager(
device_(device) {
TT_ASSERT(device != nullptr, "Device must not be null");
this->validate_sub_devices();
this->populate_sub_device_ids();
this->populate_num_cores();
this->populate_sub_allocators();
this->populate_noc_data();
Expand All @@ -45,6 +48,7 @@ SubDeviceManager::SubDeviceManager(Device *device, std::unique_ptr<Allocator> &&
this->sub_devices_ = {SubDevice(std::array{
CoreRangeSet(CoreRange({0, 0}, {compute_grid_size.x - 1, compute_grid_size.y - 1})),
CoreRangeSet(std::move(active_eth_core_ranges))})};
this->populate_sub_device_ids();
// No need to validate sub-devices since this constructs a sub-device of the entire grid
this->populate_num_cores();
this->sub_device_allocators_.push_back(std::move(global_allocator));
Expand All @@ -68,6 +72,10 @@ SubDeviceManager::~SubDeviceManager() {

uint8_t SubDeviceManager::num_sub_devices() const { return this->sub_devices_.size(); }

const std::vector<SubDeviceId> &SubDeviceManager::get_sub_device_ids() const {
return this->sub_device_ids_;
}

const SubDevice& SubDeviceManager::sub_device(SubDeviceId sub_device_id) const {
auto sub_device_index = this->get_sub_device_index(sub_device_id);
return sub_devices_[sub_device_index];
Expand Down Expand Up @@ -143,6 +151,7 @@ uint8_t SubDeviceManager::get_sub_device_index(SubDeviceId sub_device_id) const
}

void SubDeviceManager::validate_sub_devices() const {
TT_FATAL(this->sub_devices_.size() <= SubDeviceManager::MAX_NUM_SUB_DEVICES, "Too many sub devices specified");
// Validate sub device cores fit inside the device grid
const auto& compute_grid_size = this->device_->compute_with_storage_grid_size();
CoreRange device_worker_cores = CoreRange({0, 0}, {compute_grid_size.x - 1, compute_grid_size.y - 1});
Expand Down Expand Up @@ -181,6 +190,13 @@ void SubDeviceManager::validate_sub_devices() const {
}
}

void SubDeviceManager::populate_sub_device_ids() {
this->sub_device_ids_.resize(this->num_sub_devices());
for (uint8_t i = 0; i < this->num_sub_devices(); ++i) {
this->sub_device_ids_[i] = SubDeviceId{i};
}
}

void SubDeviceManager::populate_num_cores() {
for (const auto& sub_device : this->sub_devices_) {
for (uint32_t i = 0; i < NumHalProgrammableCoreTypes; ++i) {
Expand Down
4 changes: 4 additions & 0 deletions tt_metal/impl/sub_device/sub_device_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class SubDeviceManager {

~SubDeviceManager();

const std::vector<SubDeviceId> &get_sub_device_ids() const;

const SubDevice &sub_device(SubDeviceId sub_device_id) const;
const vector_memcpy_aligned<uint32_t> &noc_mcast_data(SubDeviceId sub_device_id) const;
const vector_memcpy_aligned<uint32_t> &noc_unicast_data(SubDeviceId sub_device_id) const;
Expand All @@ -65,12 +67,14 @@ class SubDeviceManager {
private:
void validate_sub_devices() const;
uint8_t get_sub_device_index(SubDeviceId sub_device_id) const;
void populate_sub_device_ids();
void populate_num_cores();
void populate_sub_allocators();
void populate_noc_data();

// TODO: We have a max number of sub-devices, so we can use a fixed size array
std::vector<SubDevice> sub_devices_;
std::vector<SubDeviceId> sub_device_ids_;
Device *device_;

DeviceAddr local_l1_size_;
Expand Down

0 comments on commit 5c58e48

Please sign in to comment.