Skip to content

Commit

Permalink
#0: New version of configure_command_queue_programs
Browse files Browse the repository at this point in the history
  • Loading branch information
tt-dma committed Nov 25, 2024
1 parent dd4ab7a commit 951f2b5
Show file tree
Hide file tree
Showing 8 changed files with 181 additions and 9 deletions.
61 changes: 54 additions & 7 deletions tt_metal/impl/device/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2682,6 +2682,49 @@ void Device::compile_command_queue_programs() {
}
}

void Device::configure_command_queue_programs_new() {
chip_id_t device_id = this->id();
chip_id_t mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(device_id);
Device *mmio_device = tt::DevicePool::instance().get_active_device(mmio_device_id);
uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(device_id);
log_debug(tt::LogMetal, "Device {} - Channel {}", this->id_, channel);

std::vector<uint32_t> zero = {0x0}; // Reset state in case L1 Clear is disabled.
std::vector<uint32_t> pointers;
uint32_t cq_size = this->sysmem_manager().get_cq_size();
TT_ASSERT(this->command_queue_programs.size() == 1);

Program& command_queue_program = *this->command_queue_programs[0];
uint8_t num_hw_cqs = this->num_hw_cqs();

// Reset host-side command queue pointers
CoreType dispatch_core_type = dispatch_core_manager::instance().get_dispatch_core_type(mmio_device_id);
uint32_t host_issue_q_rd_ptr = dispatch_constants::get(dispatch_core_type).get_host_command_queue_addr(CommandQueueHostAddrType::ISSUE_Q_RD);
uint32_t host_issue_q_wr_ptr = dispatch_constants::get(dispatch_core_type).get_host_command_queue_addr(CommandQueueHostAddrType::ISSUE_Q_WR);
uint32_t host_completion_q_wr_ptr = dispatch_constants::get(dispatch_core_type).get_host_command_queue_addr(CommandQueueHostAddrType::COMPLETION_Q_WR);
uint32_t host_completion_q_rd_ptr = dispatch_constants::get(dispatch_core_type).get_host_command_queue_addr(CommandQueueHostAddrType::COMPLETION_Q_RD);
uint32_t cq_start = dispatch_constants::get(dispatch_core_type).get_host_command_queue_addr(CommandQueueHostAddrType::UNRESERVED);
pointers.resize(cq_start/sizeof(uint32_t));
for (uint8_t cq_id = 0; cq_id < num_hw_cqs; cq_id++) {
// Reset the host manager's pointer for this command queue
this->sysmem_manager_->reset(cq_id);

pointers[host_issue_q_rd_ptr / sizeof(uint32_t)] = (cq_start + get_absolute_cq_offset(channel, cq_id, cq_size)) >> 4;
pointers[host_issue_q_wr_ptr / sizeof(uint32_t)] = (cq_start + get_absolute_cq_offset(channel, cq_id, cq_size)) >> 4;
pointers[host_completion_q_wr_ptr / sizeof(uint32_t)] = (cq_start + this->sysmem_manager_->get_issue_queue_size(cq_id) + get_absolute_cq_offset(channel, cq_id, cq_size)) >> 4;
pointers[host_completion_q_rd_ptr / sizeof(uint32_t)] = (cq_start + this->sysmem_manager_->get_issue_queue_size(cq_id) + get_absolute_cq_offset(channel, cq_id, cq_size)) >> 4;

tt::Cluster::instance().write_sysmem(pointers.data(), pointers.size() * sizeof(uint32_t), get_absolute_cq_offset(channel, cq_id, cq_size), mmio_device_id, get_umd_channel(channel));
}

// Write device-side cq pointers
configure_dispatch_cores(this);

// Run the cq program
detail::ConfigureDeviceWithProgram(this, command_queue_program, true);
tt::Cluster::instance().l1_barrier(this->id());
}

// Writes issue and completion queue pointers to device and in sysmem and loads fast dispatch program onto dispatch cores
void Device::configure_command_queue_programs() {
chip_id_t device_id = this->id();
Expand All @@ -2698,7 +2741,7 @@ void Device::configure_command_queue_programs() {
TT_ASSERT(this->command_queue_programs.size() == 1);
} else {
uint32_t program_size = tt::Cluster::instance().get_device_tunnel_depth(device_id) == 1 ? 2 : 1;
if (getenv("TT_METAL_NEW"))
if (llrt::OptionsG.get_use_new_fd_init())
program_size = 1;
TT_ASSERT(this->command_queue_programs.size() == program_size);
}
Expand Down Expand Up @@ -2786,7 +2829,7 @@ void Device::configure_command_queue_programs() {

detail::ConfigureDeviceWithProgram(this, command_queue_program, true);
tt::Cluster::instance().l1_barrier(this->id());
if (device_id != mmio_device_id && !getenv("TT_METAL_NEW")) {
if (device_id != mmio_device_id && !llrt::OptionsG.get_use_new_fd_init()) {
if (tt::Cluster::instance().get_device_tunnel_depth(device_id) == 1) {
// first or only remote device on the tunnel, launch fd2 kernels on mmio device for all remote devices.
Program &mmio_command_queue_program = *this->command_queue_programs[1];
Expand Down Expand Up @@ -2823,7 +2866,7 @@ void Device::init_command_queue_device() {

if (llrt::OptionsG.get_skip_loading_fw()) {
detail::EnablePersistentKernelCache();
if (getenv("TT_METAL_NEW")) {
if (llrt::OptionsG.get_use_new_fd_init()) {
log_warning("Running new FD init");
this->compile_command_queue_programs_new();
} else {
Expand All @@ -2832,7 +2875,7 @@ void Device::init_command_queue_device() {
}
detail::DisablePersistentKernelCache();
} else {
if (getenv("TT_METAL_NEW")) {
if (llrt::OptionsG.get_use_new_fd_init()) {
log_warning("Running new FD init");
this->compile_command_queue_programs_new();
} else {
Expand All @@ -2845,11 +2888,15 @@ void Device::init_command_queue_device() {
TT_ASSERT(this->command_queue_programs.size() == 1);
} else {
uint32_t program_size = tt::Cluster::instance().get_device_tunnel_depth(this->id()) == 1 ? 2 : 1;
if (getenv("TT_METAL_NEW"))
if (llrt::OptionsG.get_use_new_fd_init())
program_size = 1;
TT_ASSERT(this->command_queue_programs.size() == program_size);
}
this->configure_command_queue_programs();
if (llrt::OptionsG.get_use_new_fd_init()) {
this->configure_command_queue_programs_new();
} else {
this->configure_command_queue_programs();
}
Program& command_queue_program = *this->command_queue_programs[0];
command_queue_program.finalize(this);

Expand All @@ -2866,7 +2913,7 @@ void Device::init_command_queue_device() {
}
}

if (!this->is_mmio_capable() && !getenv("TT_METAL_NEW")) {
if (!this->is_mmio_capable() && !llrt::OptionsG.get_use_new_fd_init()) {
if (tt::Cluster::instance().get_device_tunnel_depth(this->id()) == 1) {
chip_id_t mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(this->id());
Device *mmio_device = tt::DevicePool::instance().get_active_device(mmio_device_id);
Expand Down
1 change: 1 addition & 0 deletions tt_metal/impl/device/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ class Device {
void compile_command_queue_programs();
void compile_command_queue_programs_new();
void configure_command_queue_programs();
void configure_command_queue_programs_new();
void clear_l1_state();
void get_associated_dispatch_phys_cores(
std::unordered_map<chip_id_t, std::unordered_set<CoreCoord>> &my_dispatch_cores,
Expand Down
47 changes: 45 additions & 2 deletions tt_metal/impl/dispatch/arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// SPDX-License-Identifier: Apache-2.0

#include "dispatch_kernels.hpp"
#include "impl/device/device_pool.hpp"
#include "tt_metal/detail/tt_metal.hpp"

#define DISPATCH_MAX_UPSTREAM 4
Expand Down Expand Up @@ -90,7 +91,8 @@ static const std::vector<dispatch_kernel_node_t> two_card_arch_2cq = {

std::vector<FDKernel *> node_id_to_kernel;

std::unique_ptr<Program> create_mmio_cq_program(Device *device) {
// Helper function to get the right struct for dispatch kernels. TODO: replace with reading yaml file later?
inline std::vector<dispatch_kernel_node_t> get_nodes(Device *device) {
std::vector<dispatch_kernel_node_t> nodes;
uint32_t num_devices = tt::Cluster::instance().number_of_user_devices();
if (num_devices == 1) { // E150, N150
Expand All @@ -116,6 +118,7 @@ std::unique_ptr<Program> create_mmio_cq_program(Device *device) {
} else { // TG, TGG
TT_FATAL(false, "Not yet implemented!");
}
#if 0
for (auto &node : nodes) {
std::string upstream = "";
for (int id : node.upstream_ids)
Expand All @@ -124,8 +127,14 @@ std::unique_ptr<Program> create_mmio_cq_program(Device *device) {
for (int id : node.downstream_ids)
downstream += fmt::format("{}, ", id);

// tt::log_info("[{}, {}, {}, {}, [{}], [{}], {}, {}, {}]", node.id, node.device_id, node.cq_id, node.kernel_type, upstream, downstream, node.my_noc, node.upstream_noc, node.downstream_noc);
tt::log_info("[{}, {}, {}, {}, [{}], [{}], {}, {}, {}]", node.id, node.device_id, node.cq_id, node.kernel_type, upstream, downstream, node.my_noc, node.upstream_noc, node.downstream_noc);
}
#endif
return nodes;
}

std::unique_ptr<Program> create_mmio_cq_program(Device *device) {
std::vector<dispatch_kernel_node_t> nodes = get_nodes(device);
if (node_id_to_kernel.empty()) {
// Do setup of kernel objects one time at the beginning, since they (1) don't need a valid Device until fields
// are populated, and (2) need to be connected to kernel objects for devices that aren't being created yet.
Expand Down Expand Up @@ -176,3 +185,37 @@ std::unique_ptr<Program> create_mmio_cq_program(Device *device) {
detail::CompileProgram(device, *cq_program_ptr, /*fd_bootloader_mode=*/true);
return cq_program_ptr;
}

// Host-side L1 setup of the dispatch cores used by `device`.
//
// For every hardware command queue, the completion queue writer core gets its completion queue read/write pointers
// set to the start of the completion queue (in 16B units) and both "last event" words cleared. Afterwards every
// kernel object in node_id_to_kernel whose node belongs to `device` is asked to configure its own core.
void configure_dispatch_cores(Device *device) {
    // Set up completion_queue_writer core. This doesn't actually have a kernel so keep it out of the struct and config
    // it here. TODO: should this be in the struct?
    CoreType dispatch_core_type = dispatch_core_manager::instance().get_dispatch_core_type(device->id());
    uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(device->id());
    auto &my_dispatch_constants = dispatch_constants::get(dispatch_core_type);
    uint32_t cq_start = my_dispatch_constants.get_host_command_queue_addr(CommandQueueHostAddrType::UNRESERVED);
    uint32_t cq_size = device->sysmem_manager().get_cq_size();

    // These L1 addresses do not depend on the command queue id, so they are looked up once instead of per queue.
    const uint32_t completion_q_wr_ptr = my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q_WR);
    const uint32_t completion_q_rd_ptr = my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q_RD);
    const uint32_t completion_q0_last_event_ptr = my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q0_LAST_EVENT);
    const uint32_t completion_q1_last_event_ptr = my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q1_LAST_EVENT);

    std::vector<uint32_t> zero = {0x0};
    for (uint8_t cq_id = 0; cq_id < device->num_hw_cqs(); cq_id++) {
        tt_cxy_pair completion_q_writer_location = dispatch_core_manager::instance().completion_queue_writer_core(device->id(), channel, cq_id);
        // The writes go through the device that hosts the completion queue writer core.
        Device *mmio_device = tt::DevicePool::instance().get_active_device(completion_q_writer_location.chip);

        // Initialize completion queue write pointer and read pointer copy
        uint32_t issue_queue_size = device->sysmem_manager().get_issue_queue_size(cq_id);
        uint32_t completion_queue_start_addr = cq_start + issue_queue_size + get_absolute_cq_offset(channel, cq_id, cq_size);
        uint32_t completion_queue_start_addr_16B = completion_queue_start_addr >> 4;
        std::vector<uint32_t> completion_queue_wr_ptr = {completion_queue_start_addr_16B};
        detail::WriteToDeviceL1(mmio_device, completion_q_writer_location, completion_q_rd_ptr, completion_queue_wr_ptr, dispatch_core_type);
        detail::WriteToDeviceL1(mmio_device, completion_q_writer_location, completion_q_wr_ptr, completion_queue_wr_ptr, dispatch_core_type);
        detail::WriteToDeviceL1(mmio_device, completion_q_writer_location, completion_q0_last_event_ptr, zero, dispatch_core_type);
        detail::WriteToDeviceL1(mmio_device, completion_q_writer_location, completion_q1_last_event_ptr, zero, dispatch_core_type);
    }

    // Kernel objects are indexed by node id; nodes.at() throws if the node list is shorter than node_id_to_kernel.
    const std::vector<dispatch_kernel_node_t> nodes = get_nodes(device);
    // size_t index: avoids the signed/unsigned comparison against size().
    for (size_t idx = 0; idx < node_id_to_kernel.size(); idx++) {
        if (nodes.at(idx).device_id == device->id()) {
            node_id_to_kernel[idx]->ConfigureCore();
        }
    }
}
3 changes: 3 additions & 0 deletions tt_metal/impl/dispatch/arch.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0
#pragma once

std::unique_ptr<Program> create_mmio_cq_program(Device *device);

void configure_dispatch_cores(Device *device);
64 changes: 64 additions & 0 deletions tt_metal/impl/dispatch/dispatch_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "dispatch_kernels.hpp"
#include "tt_metal/host_api.hpp"
#include "tt_metal/detail/tt_metal.hpp"
#include "impl/debug/dprint_server.hpp"

#define UNUSED_LOGICAL_CORE tt_cxy_pair(this->device->id(), 0, 0)
Expand Down Expand Up @@ -1492,3 +1493,66 @@ void EthRouterKernel::CreateKernel() {
};
configure_kernel_variant(dispatch_kernel_file_names[PACKET_ROUTER_MUX], compile_args, defines, false, false, false);
}

// Host-side L1 initialization of this prefetch kernel's core.
//
// Does nothing unless the kernel is an H variant. For an H variant it writes, in this order:
//   1. the prefetch queue read pointer, set to prefetch_q_base + prefetch_q_size(),
//   2. the prefetch queue PCIe read pointer, set to this command queue's absolute offset plus cq_start,
//   3. an all-zero prefetch queue of prefetch_q_entries() words at prefetch_q_base.
// Note: is_h_variant.value() is called unconditionally, as in the original code.
void PrefetchKernel::ConfigureCore() {
    // Only H-type prefetchers need L1 configuration
    if (!this->config.is_h_variant.value()) {
        return;
    }

    uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(this->device->id());
    auto &my_dispatch_constants = dispatch_constants::get(GetCoreType());
    uint32_t cq_start = my_dispatch_constants.get_host_command_queue_addr(CommandQueueHostAddrType::UNRESERVED);
    uint32_t cq_size = device->sysmem_manager().get_cq_size();

    // L1 addresses that get written below.
    uint32_t prefetch_q_base =
        my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::UNRESERVED);
    uint32_t prefetch_q_rd_ptr =
        my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::PREFETCH_Q_RD);
    uint32_t prefetch_q_pcie_rd_ptr =
        my_dispatch_constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::PREFETCH_Q_PCIE_RD);

    // Initial contents: zeroed FetchQ entries plus the two read pointers.
    std::vector<uint32_t> prefetch_q(my_dispatch_constants.prefetch_q_entries(), 0);
    std::vector<uint32_t> prefetch_q_rd_ptr_addr_data = {
        static_cast<uint32_t>(prefetch_q_base + my_dispatch_constants.prefetch_q_size())};
    std::vector<uint32_t> prefetch_q_pcie_rd_ptr_addr_data = {
        get_absolute_cq_offset(channel, cq_id, cq_size) + cq_start};

    detail::WriteToDeviceL1(
        device, this->logical_core, prefetch_q_rd_ptr, prefetch_q_rd_ptr_addr_data, GetCoreType());
    detail::WriteToDeviceL1(
        device, this->logical_core, prefetch_q_pcie_rd_ptr, prefetch_q_pcie_rd_ptr_addr_data, GetCoreType());
    detail::WriteToDeviceL1(device, this->logical_core, prefetch_q_base, prefetch_q, GetCoreType());
}

// Host-side L1 reset of this dispatch kernel's core.
//
// The dispatch message word is cleared for every dispatcher. When the kernel is a D variant and not an H variant
// (the DISPATCH_D case), both completion queue "last event" words are cleared as well.
void DispatchKernel::ConfigureCore() {
    std::vector<uint32_t> cleared_word{0x0};
    auto &constants = dispatch_constants::get(GetCoreType());

    // Every dispatcher starts from a cleared dispatch message.
    uint32_t message_addr = constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE);
    detail::WriteToDeviceL1(device, this->logical_core, message_addr, cleared_word, GetCoreType());

    // Anything that is an H variant is done here; is_d_variant is only inspected for non-H kernels.
    if (this->config.is_h_variant.value()) {
        return;
    }
    if (!this->config.is_d_variant.value()) {
        return;
    }

    // DISPATCH_D: clear the completion queue event words too.
    uint32_t last_event_q0_addr =
        constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q0_LAST_EVENT);
    uint32_t last_event_q1_addr =
        constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::COMPLETION_Q1_LAST_EVENT);
    detail::WriteToDeviceL1(device, logical_core, last_event_q0_addr, cleared_word, GetCoreType());
    detail::WriteToDeviceL1(device, logical_core, last_event_q1_addr, cleared_word, GetCoreType());
}

// Host-side L1 reset of this dispatch_s kernel's core: the only state touched is the dispatch message word, which
// is overwritten with a single zero.
void DispatchSKernel::ConfigureCore() {
    std::vector<uint32_t> cleared_word{0x0};
    auto &constants = dispatch_constants::get(GetCoreType());
    uint32_t message_addr = constants.get_device_command_queue_addr(CommandQueueDeviceAddrType::DISPATCH_MESSAGE);
    detail::WriteToDeviceL1(device, this->logical_core, message_addr, cleared_word, GetCoreType());
}
5 changes: 5 additions & 0 deletions tt_metal/impl/dispatch/dispatch_kernels.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0
#pragma once

#include "impl/device/device.hpp"
#include "impl/program/program.hpp"
Expand All @@ -19,6 +20,7 @@ class FDKernel {
virtual void CreateKernel() = 0;
virtual void GenerateStaticConfigs() = 0;
virtual void GenerateDependentConfigs() = 0;
virtual void ConfigureCore() {}; // Overridden for specific kernels that need host-side configuration
static FDKernel *Generate(int node_id, uint8_t cq_id, noc_selection_t noc_selection, DispatchWorkerType type);

void AddUpstreamKernel(FDKernel *upstream) { this->upstream_kernels.push_back(upstream); }
Expand Down Expand Up @@ -302,6 +304,7 @@ class PrefetchKernel : public FDKernel {
void CreateKernel() override;
void GenerateStaticConfigs() override;
void GenerateDependentConfigs() override;
void ConfigureCore() override;
const prefetch_config_t &GetConfig() { return this->config; }

private:
Expand All @@ -318,6 +321,7 @@ class DispatchKernel : public FDKernel {
void CreateKernel() override;
void GenerateStaticConfigs() override;
void GenerateDependentConfigs() override;
void ConfigureCore() override;
const dispatch_config_t &GetConfig() { return this->config; }

private:
Expand All @@ -330,6 +334,7 @@ class DispatchSKernel : public FDKernel {
void CreateKernel() override;
void GenerateStaticConfigs() override;
void GenerateDependentConfigs() override;
void ConfigureCore() override;
const dispatch_s_config_t &GetConfig() { return this->config; }

private:
Expand Down
4 changes: 4 additions & 0 deletions tt_metal/llrt/rtoptions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ RunTimeOptions::RunTimeOptions() {
enable_dispatch_data_collection = true;
}

if (getenv("TT_METAL_NEW")) {
this->use_new_fd_init = true;
}

if (getenv("TT_METAL_GTEST_ETH_DISPATCH")) {
this->dispatch_core_type = tt_metal::DispatchCoreType::ETH;
}
Expand Down
5 changes: 5 additions & 0 deletions tt_metal/llrt/rtoptions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ class RunTimeOptions {

tt_metal::DispatchCoreType dispatch_core_type = tt_metal::DispatchCoreType::WORKER;

bool use_new_fd_init = false;

public:
RunTimeOptions();

Expand Down Expand Up @@ -280,6 +282,9 @@ class RunTimeOptions {
inline bool get_dispatch_data_collection_enabled() { return enable_dispatch_data_collection; }
inline void set_dispatch_data_collection_enabled(bool enable) { enable_dispatch_data_collection = enable; }

inline bool get_use_new_fd_init() { return use_new_fd_init; }
inline void set_use_new_fd_init(bool enable) { use_new_fd_init = enable; }

inline tt_metal::DispatchCoreType get_dispatch_core_type() { return dispatch_core_type; }

private:
Expand Down

0 comments on commit 951f2b5

Please sign in to comment.