Skip to content

Commit

Permalink
#14361: Finish implementing metal 1.0 API
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickroberts committed Nov 22, 2024
1 parent b5a7147 commit b43685b
Show file tree
Hide file tree
Showing 18 changed files with 470 additions and 261 deletions.
1 change: 1 addition & 0 deletions tt_metal/impl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ set(IMPL_SRC
${CMAKE_CURRENT_SOURCE_DIR}/debug/watcher_device_reader.cpp
${CMAKE_CURRENT_SOURCE_DIR}/trace/trace.cpp
${CMAKE_CURRENT_SOURCE_DIR}/trace/trace_buffer.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event/event.cpp
)

add_library(impl OBJECT ${IMPL_SRC})
Expand Down
19 changes: 18 additions & 1 deletion tt_metal/impl/buffers/buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@

#include "tt_metal/impl/buffers/buffer.hpp"

#include "tt_metal/buffer.hpp"
#include "tt_metal/common/assert.hpp"
#include "tt_metal/common/math.hpp"
#include "tt_metal/detail/tt_metal.hpp"
#include "tt_metal/impl/allocator/allocator.hpp"
#include "tt_metal/impl/device/device.hpp"
#include "tt_metal/types.hpp"

#include <algorithm>
#include <mutex>
#include <string>
#include <utility>
#include "tt_metal/common/base.hpp"
#include "tt_metal/impl/buffers/buffer_constants.hpp"
Expand Down Expand Up @@ -514,6 +515,22 @@ DeviceAddr ShardSpecBuffer::size() const {
return shape_in_pages_[0] * shape_in_pages_[1];
}

// Builds a buffer from `config` via v0::CreateBuffer and wraps the result in a v1::BufferHandle.
v1::BufferHandle v1::CreateBuffer(InterleavedBufferConfig config) {
    auto created = v0::CreateBuffer(config);
    return v1::BufferHandle{std::move(created)};
}

// Dereferences the handle and forwards the underlying buffer to v0::DeallocateBuffer.
void v1::DeallocateBuffer(BufferHandle buffer) {
    auto &&target = *buffer;
    v0::DeallocateBuffer(target);
}

// Writes the bytes viewed by `host_buffer` into the buffer behind `buffer`.
// The std::byte view is reinterpreted as a uint8_t view of the same length
// before being handed to detail::WriteToBuffer.
void v1::WriteToBuffer(BufferHandle buffer, stl::Span<const std::byte> host_buffer) {
    const auto *bytes = reinterpret_cast<const std::uint8_t *>(host_buffer.data());
    const stl::Span<const uint8_t> raw_view{bytes, host_buffer.size()};
    detail::WriteToBuffer(*buffer, raw_view);
}

// Reads the buffer behind `buffer` into the memory viewed by `host_buffer`,
// passing `shard_order` through to detail::ReadFromBuffer.
// NOTE(review): only host_buffer.data() is forwarded; host_buffer.size() is not
// checked here, so presumably the caller must supply a large enough span — confirm.
void v1::ReadFromBuffer(BufferHandle buffer, stl::Span<std::byte> host_buffer, bool shard_order) {
    std::uint8_t *destination = reinterpret_cast<std::uint8_t *>(host_buffer.data());
    detail::ReadFromBuffer(*buffer, destination, shard_order);
}

// Reads the shard selected by `core_id` from the buffer behind `buffer` into the
// memory viewed by `host_buffer`, delegating to detail::ReadShard.
// NOTE(review): only host_buffer.data() is forwarded; host_buffer.size() is not
// checked here, so presumably the caller must supply a large enough span — confirm.
void v1::ReadFromShard(BufferHandle buffer, stl::Span<std::byte> host_buffer, std::uint32_t core_id) {
    std::uint8_t *destination = reinterpret_cast<std::uint8_t *>(host_buffer.data());
    detail::ReadShard(*buffer, destination, core_id);
}

} // namespace tt_metal
} // namespace tt

Expand Down
93 changes: 85 additions & 8 deletions tt_metal/impl/device/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
// SPDX-License-Identifier: Apache-2.0

#include <string>
#include <chrono>
#include <type_traits>
#include <thread>
#include "tt_metal/device.hpp"
#include "common/core_coord.hpp"
#include "tt_metal/host_api.hpp"
#include "tt_metal/jit_build/genfiles.hpp"
#include "tt_metal/impl/device/device.hpp"
Expand All @@ -14,7 +15,6 @@
#include "tt_metal/detail/tt_metal.hpp"
#include "impl/debug/dprint_server.hpp"
#include "impl/debug/watcher_server.hpp"
#include "common/env_lib.hpp"
#include "tt_metal/impl/dispatch/kernels/packet_queue_ctrl.hpp"
#include "common/utils.hpp"
#include "llrt/llrt.hpp"
Expand All @@ -28,13 +28,14 @@
#include "tt_metal/impl/sub_device/sub_device_manager.hpp"
#include "tt_metal/impl/sub_device/sub_device_types.hpp"
#include "tt_metal/tt_stl/span.hpp"
#include "tt_metal/types.hpp"

namespace tt {

namespace tt_metal {

Device::Device(
chip_id_t device_id, const uint8_t num_hw_cqs, size_t l1_small_size, size_t trace_region_size, const std::vector<uint32_t> &l1_bank_remap, bool minimal, uint32_t worker_core, uint32_t completion_queue_reader_core) :
chip_id_t device_id, const uint8_t num_hw_cqs, size_t l1_small_size, size_t trace_region_size, tt::stl::Span<const std::uint32_t> l1_bank_remap, bool minimal, uint32_t worker_core, uint32_t completion_queue_reader_core) :
id_(device_id), worker_thread_core(worker_core), completion_queue_reader_core(completion_queue_reader_core), work_executor(worker_core, device_id) {
ZoneScoped;
tunnel_device_dispatch_workers_ = {};
Expand Down Expand Up @@ -208,7 +209,7 @@ void Device::initialize_cluster() {
log_info(tt::LogMetal, "AI CLK for device {} is: {} MHz", this->id_, ai_clk);
}

void Device::initialize_default_sub_device_state(size_t l1_small_size, size_t trace_region_size, const std::vector<uint32_t> &l1_bank_remap) {
void Device::initialize_default_sub_device_state(size_t l1_small_size, size_t trace_region_size, tt::stl::Span<const std::uint32_t> l1_bank_remap) {
// Create the default sub-device manager representing the entire chip
this->next_sub_device_manager_id_ = {0};
auto [sub_device_manager, _] = this->sub_device_managers_.insert_or_assign(this->get_next_sub_device_manager_id(), std::make_unique<detail::SubDeviceManager>(this, this->initialize_allocator(l1_small_size, trace_region_size, l1_bank_remap)));
Expand All @@ -220,7 +221,7 @@ void Device::initialize_default_sub_device_state(size_t l1_small_size, size_t tr

}

std::unique_ptr<Allocator> Device::initialize_allocator(size_t l1_small_size, size_t trace_region_size, const std::vector<uint32_t> &l1_bank_remap) {
std::unique_ptr<Allocator> Device::initialize_allocator(size_t l1_small_size, size_t trace_region_size, tt::stl::Span<const std::uint32_t> l1_bank_remap) {
ZoneScoped;
const metal_SocDescriptor &soc_desc = tt::Cluster::instance().get_soc_desc(this->id_);
CoreType dispatch_core_type = dispatch_core_manager::instance().get_dispatch_core_type(this->id_);
Expand All @@ -246,7 +247,7 @@ std::unique_ptr<Allocator> Device::initialize_allocator(size_t l1_small_size, si
.core_type_from_noc_coord_table = {}, // Populated later
.worker_log_to_physical_routing_x = soc_desc.worker_log_to_physical_routing_x,
.worker_log_to_physical_routing_y = soc_desc.worker_log_to_physical_routing_y,
.l1_bank_remap = l1_bank_remap,
.l1_bank_remap = {l1_bank_remap.begin(), l1_bank_remap.end()},
.compute_grid = CoreRangeSet(CoreRange(CoreCoord(0, 0), CoreCoord(compute_size.x - 1, compute_size.y - 1))),
.alignment = std::max(hal.get_alignment(HalMemType::DRAM), hal.get_alignment(HalMemType::L1)),
.disable_interleaved = false});
Expand Down Expand Up @@ -2930,7 +2931,7 @@ void Device::initialize_synchronous_sw_cmd_queue() {
}
}

bool Device::initialize(const uint8_t num_hw_cqs, size_t l1_small_size, size_t trace_region_size, const std::vector<uint32_t> &l1_bank_remap, bool minimal) {
bool Device::initialize(const uint8_t num_hw_cqs, size_t l1_small_size, size_t trace_region_size, tt::stl::Span<const std::uint32_t> l1_bank_remap, bool minimal) {
ZoneScoped;
log_info(tt::LogMetal, "Initializing device {}. Program cache is {}enabled", this->id_, this->program_cache.is_enabled() ? "": "NOT ");
log_debug(tt::LogMetal, "Running with {} cqs ", num_hw_cqs);
Expand Down Expand Up @@ -3665,6 +3666,82 @@ const std::vector<SubDeviceId> &Device::get_sub_device_ids() const {
return this->active_sub_device_manager_->get_sub_device_ids();
}

// Returns the number of user devices reported by the shared tt::Cluster instance.
size_t v1::GetNumAvailableDevices() {
    auto &cluster = tt::Cluster::instance();
    return cluster.number_of_user_devices();
}

// Returns the number of PCI devices reported by the shared tt::Cluster instance.
size_t v1::GetNumPCIeDevices() {
    auto &cluster = tt::Cluster::instance();
    return cluster.number_of_pci_devices();
}

// Looks up the MMIO device associated with `device_id` in the shared tt::Cluster instance.
chip_id_t v1::GetPCIeDeviceID(chip_id_t device_id) {
    auto &cluster = tt::Cluster::instance();
    const chip_id_t mmio_device_id = cluster.get_associated_mmio_device(device_id);
    return mmio_device_id;
}

// Initializes the device pool for the single device `device_id` using the
// settings carried in `options`, then returns the pool's active-device handle
// for that id.
v1::DeviceHandle v1::CreateDevice(chip_id_t device_id, CreateDeviceOptions options) {
    ZoneScoped;

    // The pool is (re)initialized with a one-element device list before lookup.
    tt::DevicePool::initialize(
        {device_id},
        options.num_hw_cqs,
        options.l1_small_size,
        options.trace_region_size,
        options.dispatch_core_type,
        options.l1_bank_remap);

    auto &pool = tt::DevicePool::instance();
    return pool.get_active_device(device_id);
}

// Closes `device` through the v0 API and returns that call's result.
bool v1::CloseDevice(DeviceHandle device) {
    const bool closed = v0::CloseDevice(device);
    return closed;
}

// Asks `device` to deallocate its buffers.
void v1::DeallocateBuffers(DeviceHandle device) {
    device->deallocate_buffers();
}

// Expands `worker_cores` with corerange_to_cores and forwards the result,
// together with `last_dump`, to detail::DumpDeviceProfileResults.
void v1::DumpDeviceProfileResults(DeviceHandle device, const CoreRangeSet &worker_cores, bool last_dump) {
    // Kept as a named, non-const local so it is passed as an lvalue, exactly as before.
    auto cores = corerange_to_cores(worker_cores);
    detail::DumpDeviceProfileResults(device, cores, last_dump);
}

// Returns the architecture reported by `device`.
ARCH v1::GetArch(DeviceHandle device) {
    const ARCH arch = device->arch();
    return arch;
}

// Returns the chip id reported by `device`.
chip_id_t v1::GetId(DeviceHandle device) {
    const chip_id_t id = device->id();
    return id;
}

// Returns the DRAM channel count reported by `device`.
int v1::GetNumDramChannels(DeviceHandle device) {
    const int channel_count = device->num_dram_channels();
    return channel_count;
}

// Returns the per-core L1 size reported by `device`.
std::uint32_t v1::GetL1SizePerCore(DeviceHandle device) {
    const std::uint32_t l1_size = device->l1_size_per_core();
    return l1_size;
}

// Returns the compute-with-storage grid size reported by `device`.
CoreCoord v1::GetComputeWithStorageGridSize(DeviceHandle device) {
    CoreCoord grid_size = device->compute_with_storage_grid_size();
    return grid_size;
}

// Returns the DRAM grid size reported by `device`.
CoreCoord v1::GetDramGridSize(DeviceHandle device) {
    CoreCoord grid_size = device->dram_grid_size();
    return grid_size;
}

// Turns on the program cache of `device`.
void v1::EnableProgramCache(DeviceHandle device) {
    device->enable_program_cache();
}

// Disables the program cache of `device` and clears it, via the device's own
// disable_and_clear_program_cache().
void v1::DisableAndClearProgramCache(DeviceHandle device) {
    device->disable_and_clear_program_cache();
}

// Hands `work` to `device` via push_work; `blocking` is passed through unchanged.
void v1::PushWork(DeviceHandle device, std::function<void()> work, bool blocking) {
    // `work` was taken by value, so it is moved into the device call rather than copied.
    device->push_work(
        std::move(work),
        blocking);
}

// Invokes synchronize() on `device`.
void v1::Synchronize(DeviceHandle device) {
    device->synchronize();
}

// Returns the ethernet sockets `device` reports for `connected_chip_id`.
std::vector<CoreCoord> v1::GetEthernetSockets(DeviceHandle device, chip_id_t connected_chip_id) {
    auto sockets = device->get_ethernet_sockets(connected_chip_id);
    return sockets;
}

// Returns the number of banks `device` reports for `buffer_type`.
std::uint32_t v1::GetNumBanks(DeviceHandle device, BufferType buffer_type) {
    const std::uint32_t bank_count = device->num_banks(buffer_type);
    return bank_count;
}

// Returns the offset `device` reports for bank `bank_id` of `buffer_type`.
std::int32_t v1::GetBankOffset(DeviceHandle device, BufferType buffer_type, std::uint32_t bank_id) {
    const std::int32_t offset = device->bank_offset(buffer_type, bank_id);
    return offset;
}

// Returns a view of the bank ids `device` reports for `logical_core` and `buffer_type`.
// NOTE(review): the result is a span over whatever bank_ids_from_logical_core
// returns; presumably that storage is owned by the device and outlives the
// call — confirm before holding on to the span.
tt::stl::Span<const std::uint32_t> v1::BankIdsFromLogicalCore(
    DeviceHandle device,
    BufferType buffer_type,
    CoreCoord logical_core) {
    return device->bank_ids_from_logical_core(buffer_type, logical_core);
}

// Returns the value of sfpu_eps() reported by `device`.
float v1::GetSfpuEps(DeviceHandle device) {
    const float eps = device->sfpu_eps();
    return eps;
}

// Returns the value of sfpu_nan() reported by `device`.
float v1::GetSfpuNan(DeviceHandle device) {
    const float nan_value = device->sfpu_nan();
    return nan_value;
}

// Returns the value of sfpu_inf() reported by `device`.
float v1::GetSfpuInf(DeviceHandle device) {
    const float inf_value = device->sfpu_inf();
    return inf_value;
}

// Returns the number of program cache entries reported by `device`.
std::size_t v1::GetNumProgramCacheEntries(DeviceHandle device) {
    const std::size_t entry_count = device->num_program_cache_entries();
    return entry_count;
}

} // namespace tt_metal

} // namespace tt
8 changes: 4 additions & 4 deletions tt_metal/impl/device/device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class Device {
const uint8_t num_hw_cqs,
std::size_t l1_small_size,
std::size_t trace_region_size,
const std::vector<uint32_t> &l1_bank_remap = {},
tt::stl::Span<const std::uint32_t> l1_bank_remap = {},
bool minimal = false,
uint32_t worker_core = 0,
uint32_t completion_queue_reader_core = 0);
Expand Down Expand Up @@ -253,9 +253,9 @@ class Device {

// Checks that the given arch is on the given pci_slot and that it's responding
// Puts device into reset
bool initialize(const uint8_t num_hw_cqs, size_t l1_small_size, size_t trace_region_size, const std::vector<uint32_t> &l1_bank_remap = {}, bool minimal = false);
bool initialize(const uint8_t num_hw_cqs, size_t l1_small_size, size_t trace_region_size, tt::stl::Span<const std::uint32_t> l1_bank_remap = {}, bool minimal = false);
void initialize_cluster();
std::unique_ptr<Allocator> initialize_allocator(size_t l1_small_size, size_t trace_region_size, const std::vector<uint32_t> &l1_bank_remap = {});
std::unique_ptr<Allocator> initialize_allocator(size_t l1_small_size, size_t trace_region_size, tt::stl::Span<const std::uint32_t> l1_bank_remap = {});
void initialize_build();
void initialize_device_kernel_defines();
void build_firmware();
Expand Down Expand Up @@ -383,7 +383,7 @@ class Device {
void remove_sub_device_manager(SubDeviceManagerId sub_device_manager_id);
const std::vector<SubDeviceId> &get_sub_device_ids() const;
private:
void initialize_default_sub_device_state(size_t l1_small_size, size_t trace_region_size, const std::vector<uint32_t> &l1_bank_remap);
void initialize_default_sub_device_state(size_t l1_small_size, size_t trace_region_size, tt::stl::Span<const std::uint32_t> l1_bank_remap);
SubDeviceManagerId get_next_sub_device_manager_id();
void reset_sub_devices_state(const std::unique_ptr<detail::SubDeviceManager>& sub_device_manager);
void MarkAllocationsUnsafe();
Expand Down
15 changes: 6 additions & 9 deletions tt_metal/impl/device/device_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include <numa.h>

#include "tt_metal/detail/tt_metal.hpp"
#include "tt_metal/impl/debug/noc_logging.hpp"
#include "tt_metal/impl/debug/watcher_server.hpp"
#include "tt_metal/impl/device/device_handle.hpp"

using namespace tt::tt_metal;
Expand Down Expand Up @@ -179,19 +181,19 @@ void DevicePool::initialize(
size_t l1_small_size,
size_t trace_region_size,
DispatchCoreType dispatch_core_type,
const std::vector<uint32_t> &l1_bank_remap) noexcept {
tt::stl::Span<const std::uint32_t> l1_bank_remap) noexcept {
ZoneScoped;
log_debug(tt::LogMetal, "DevicePool initialize");
tt::tt_metal::dispatch_core_manager::initialize(dispatch_core_type, num_hw_cqs);

if (_inst == nullptr) {
static DevicePool device_pool(device_ids, num_hw_cqs, l1_small_size, trace_region_size, l1_bank_remap);
static DevicePool device_pool{};
_inst = &device_pool;
}
_inst->l1_small_size = l1_small_size;
_inst->trace_region_size = trace_region_size;
_inst->num_hw_cqs = num_hw_cqs;
_inst->l1_bank_remap = l1_bank_remap;
_inst->l1_bank_remap.assign(l1_bank_remap.begin(), l1_bank_remap.end());
// Track the thread where the Device Pool was created. Certain functions
// modifying the state of this instance, for example those responsible for
// (un)registering worker threads, can only be called in the creation thread
Expand Down Expand Up @@ -388,12 +390,7 @@ void DevicePool::init_firmware_on_active_devices() const {
}
}

DevicePool::DevicePool(
std::vector<chip_id_t> device_ids,
const uint8_t num_hw_cqs,
size_t l1_small_size,
size_t trace_region_size,
const std::vector<uint32_t>& l1_bank_remap) {
DevicePool::DevicePool() {
ZoneScoped;
log_debug(tt::LogMetal, "DevicePool constructor");
bool use_numa_node_based_thread_binding = parse_env("TT_METAL_NUMA_BASED_AFFINITY", false);
Expand Down
13 changes: 3 additions & 10 deletions tt_metal/impl/device/device_pool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@

#pragma once

#include "tt_cluster_descriptor_types.h"
#include "tt_metal/host_api.hpp"
#include "impl/debug/dprint_server.hpp"
#include "impl/debug/noc_logging.hpp"
#include "impl/debug/watcher_server.hpp"
#include "tt_metal/impl/device/device.hpp"
#include "tt_metal/impl/device/device_handle.hpp"
#include "tt_metal/third_party/umd/device/tt_cluster_descriptor.h"
namespace tt {
namespace tt_metal::detail {

Expand Down Expand Up @@ -41,7 +39,7 @@ class DevicePool {
size_t l1_small_size,
size_t trace_region_size,
tt_metal::DispatchCoreType dispatch_core_type,
const std::vector<uint32_t> &l1_bank_remap = {}) noexcept;
tt::stl::Span<const std::uint32_t> l1_bank_remap = {}) noexcept;

tt_metal::v1::DeviceHandle get_active_device(chip_id_t device_id) const;
std::vector<tt_metal::v1::DeviceHandle> get_all_active_devices() const;
Expand All @@ -53,12 +51,7 @@ class DevicePool {
const std::unordered_set<std::thread::id>& get_worker_thread_ids() const;
private:
~DevicePool();
DevicePool(
std::vector<chip_id_t> device_ids,
const uint8_t num_hw_cqs,
size_t l1_small_size,
size_t trace_region_size,
const std::vector<uint32_t> &l1_bank_remap);
DevicePool();
uint8_t num_hw_cqs;
size_t l1_small_size;
size_t trace_region_size;
Expand Down
36 changes: 36 additions & 0 deletions tt_metal/impl/dispatch/command_queue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
#include "allocator/allocator.hpp"
#include "debug_tools.hpp"
#include "dev_msgs.h"
#include "device/device_handle.hpp"
#include "llrt/hal.hpp"
#include "noc/noc_parameters.h"
#include "tt_metal/command_queue.hpp"
#include "tt_metal/common/assert.hpp"
#include "tt_metal/common/logger.hpp"
#include "tt_metal/detail/tt_metal.hpp"
Expand Down Expand Up @@ -3482,6 +3484,40 @@ void CommandQueue::run_command_impl(const CommandInterface& command) {
log_trace(LogDispatch, "{} running {} complete", this->name(), command.type);
}

// Builds a command queue handle that pairs `device` with the queue index `cq_id`.
v1::CommandQueueHandle v1::GetCommandQueue(DeviceHandle device, std::uint8_t cq_id) {
    v1::CommandQueueHandle handle{device, cq_id};
    return handle;
}

// Returns the handle for command queue 0 of `device`.
v1::CommandQueueHandle v1::GetDefaultCommandQueue(DeviceHandle device) {
    constexpr std::uint8_t default_cq_id = 0;
    return GetCommandQueue(device, default_cq_id);
}

// Resolves `cq` to its device's command queue and enqueues, through the v0 API,
// a read of the buffer behind `buffer` into `dst`; `blocking` is passed through.
void v1::EnqueueReadBuffer(CommandQueueHandle cq, BufferHandle buffer, std::byte *dst, bool blocking) {
    auto &queue = GetDevice(cq)->command_queue(GetId(cq));
    v0::EnqueueReadBuffer(queue, *buffer, dst, blocking);
}

// Resolves `cq` to its device's command queue and enqueues, through the v0 API,
// a write from `src` into the buffer behind `buffer`; `blocking` is passed through.
void v1::EnqueueWriteBuffer(CommandQueueHandle cq, BufferHandle buffer, const std::byte *src, bool blocking) {
    auto &queue = GetDevice(cq)->command_queue(GetId(cq));
    v0::EnqueueWriteBuffer(queue, *buffer, src, blocking);
}

// Resolves `cq` to its device's command queue and enqueues `program` on it
// through the v0 API; `blocking` is passed through.
void v1::EnqueueProgram(CommandQueueHandle cq, ProgramHandle &program, bool blocking) {
    auto &queue = GetDevice(cq)->command_queue(GetId(cq));
    v0::EnqueueProgram(queue, program, blocking);
}

// Resolves `cq` to its device's command queue and calls v0::Finish on it.
//
// `sub_device_ids` is forwarded to v0::Finish so that a caller's sub-device
// selection is honored; previously the argument was accepted but silently
// dropped, so the v0 call always ran with its default sub-device selection.
// NOTE(review): relies on v0::Finish accepting a
// tt::stl::Span<const SubDeviceId> second argument — confirm against host_api.hpp.
void v1::Finish(CommandQueueHandle cq, tt::stl::Span<const SubDeviceId> sub_device_ids) {
    auto &queue = GetDevice(cq)->command_queue(GetId(cq));
    v0::Finish(queue, sub_device_ids);
}

// Passes `lazy` straight through to detail::SetLazyCommandQueueMode.
void v1::SetLazyCommandQueueMode(bool lazy) {
    const bool enable_lazy = lazy;
    detail::SetLazyCommandQueueMode(enable_lazy);
}

// Returns the device stored in the command queue handle.
v1::DeviceHandle v1::GetDevice(CommandQueueHandle cq) {
    v1::DeviceHandle owner = cq.device;
    return owner;
}

// Returns the queue index stored in the command queue handle.
std::uint8_t v1::GetId(CommandQueueHandle cq) {
    const std::uint8_t queue_id = cq.id;
    return queue_id;
}

} // namespace tt::tt_metal

std::ostream& operator<<(std::ostream& os, EnqueueCommandType const& type) {
Expand Down
Loading

0 comments on commit b43685b

Please sign in to comment.