Skip to content

Commit

Permalink
Convert Hal into a Singleton (#15116)
Browse files Browse the repository at this point in the history
Use thread safe singleton instead of global throughout the code base.
  • Loading branch information
blozano-tt authored Nov 19, 2024
1 parent 34fcef6 commit a356429
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 60 deletions.
6 changes: 5 additions & 1 deletion tests/tt_metal/distributed/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,8 @@ set_target_properties(
${PROJECT_BINARY_DIR}/test/tt_metal/distributed
)

gtest_discover_tests(distributed_unit_tests)
# Don't do this for now.
# When the test is probed, something is constructed that tries to access a device,
# and the build machine might not have a device.
# We don't use ctest in this project, so we shouldn't need this yet.
#gtest_discover_tests(distributed_unit_tests)
Original file line number Diff line number Diff line change
Expand Up @@ -2906,9 +2906,6 @@ int main(int argc, char **argv) {
auto slow_dispatch_mode = getenv("TT_METAL_SLOW_DISPATCH_MODE");
TT_FATAL(slow_dispatch_mode, "This test only supports TT_METAL_SLOW_DISPATCH_MODE");

// TODO(abhullar): Have to initialize the HAL explicitly here because it is accessed before Device initializes it
tt::ARCH arch = tt::Cluster::instance().arch();
hal.initialize(arch);
init(argc, argv);

bool pass = true;
Expand Down
80 changes: 80 additions & 0 deletions tt_metal/llrt/get_platform_architecture.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <cstdlib>

#include "tt_metal/common/tt_backend_api_types.hpp"
#include "tt_metal/common/assert.hpp"
#include "tt_metal/third_party/umd/device/cluster.h"

namespace tt::tt_metal {

/**
* @brief Detects the platform architecture based on the environment or hardware.
*
* This function determines the platform architecture by inspecting the environment
* variables or available physical devices. If the environment variable
* `TT_METAL_SIMULATOR_EN` is set, the architecture is retrieved from the
* `ARCH_NAME` environment variable. Otherwise, the architecture is deduced
* by detecting available physical devices.
*
* @return tt::ARCH The detected platform architecture. Returns tt::ARCH::Invalid
* if no valid architecture could be detected.
*
* @note
* - If the system is in simulation mode (`TT_METAL_SIMULATOR_EN` is set),
* the `ARCH_NAME` environment variable must be defined.
* - A fatal error occurs if multiple devices are detected with conflicting
* architectures.
*
* @exception std::runtime_error Throws a fatal error if:
* - `ARCH_NAME` is not set when `TT_METAL_SIMULATOR_EN` is enabled.
* - Multiple devices with inconsistent architectures are detected.
*
* Example usage:
* @code
* #include "tt_metal/common/tt_backend_api_types.hpp"
*
* tt::ARCH arch = tt::tt_metal::get_platform_architecture();
* if (arch == tt::ARCH::Invalid) {
* std::cerr << "Failed to detect architecture!" << std::endl;
* } else {
* std::cout << "Detected architecture: " << tt::get_arch_str(arch) << std::endl;
* }
* @endcode
*
* @see tt::get_arch_from_string
* @see tt::umd::Cluster::detect_available_device_ids
* @see detect_arch
*/
inline tt::ARCH get_platform_architecture() {
    // Simulator mode: the architecture is taken from ARCH_NAME rather than from a device query.
    if (std::getenv("TT_METAL_SIMULATOR_EN")) {
        const char* arch_env = std::getenv("ARCH_NAME");
        TT_FATAL(arch_env, "ARCH_NAME env var needed for VCS");
        return tt::get_arch_from_string(arch_env);
    }

    // Otherwise ask UMD which device ids are available.
    const std::vector<chip_id_t> physical_mmio_device_ids = tt::umd::Cluster::detect_available_device_ids();
    if (physical_mmio_device_ids.empty()) {
        // No devices found: report Invalid and let the caller decide how to react.
        return tt::ARCH::Invalid;
    }

    // The first device defines the expected architecture; every other device must match it.
    const tt::ARCH arch = detect_arch(physical_mmio_device_ids.front());
    // Index with an unsigned type so the comparison against size() is not signed/unsigned.
    for (std::size_t i = 1; i < physical_mmio_device_ids.size(); ++i) {
        const chip_id_t device_id = physical_mmio_device_ids[i];
        const tt::ARCH detected_arch = detect_arch(device_id);
        TT_FATAL(
            arch == detected_arch,
            "Expected all devices to be {} but device {} is {}",
            get_arch_str(arch),
            device_id,
            get_arch_str(detected_arch));
    }

    return arch;
}

} // namespace tt::tt_metal
45 changes: 16 additions & 29 deletions tt_metal/llrt/hal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,44 +3,31 @@
// SPDX-License-Identifier: Apache-2.0

#include "hal.hpp"
#include "tt_metal/third_party/umd/device/tt_soc_descriptor.h"
#include "tt_metal/third_party/umd/device/tt_arch_types.h"

#include "tt_metal/common/tt_backend_api_types.hpp"
#include "tt_metal/common/assert.hpp"

#include "get_platform_architecture.hpp"
namespace tt {

namespace tt_metal {

Hal hal;

// This back pointer is a little clunky but necessary at least for now
Hal::Hal() : initialized_(false) {
}

void Hal::initialize(tt::ARCH arch) {
// Hal Constructor determines the platform architecture by using UMD
// Once it knows the architecture it can self initialize architecture specific memory maps
Hal::Hal() : arch_(get_platform_architecture()) {

const std::lock_guard<std::mutex> lock(this->lock);
switch (this->arch_) {
case tt::ARCH::GRAYSKULL: initialize_gs();
break;

if (!this->initialized_) {
switch (arch) {
case tt::ARCH::GRAYSKULL:
initialize_gs();
break;

case tt::ARCH::WORMHOLE_B0:
initialize_wh();
break;

case tt::ARCH::BLACKHOLE:
initialize_bh();
break;

default:
TT_THROW("Unsupported arch for HAL");
}
case tt::ARCH::WORMHOLE_B0: initialize_wh();
break;

this->arch_ = arch;
case tt::ARCH::BLACKHOLE: initialize_bh();
break;

this->initialized_ = true;
case tt::ARCH::Invalid: /*TT_THROW("Unsupported arch for HAL")*/;
break;
}
}

Expand Down
23 changes: 18 additions & 5 deletions tt_metal/llrt/hal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,6 @@ inline T HalCoreInfoType::get_binary_local_init_addr(uint32_t processor_class_id

class Hal {
private:
std::mutex lock;
bool initialized_;
tt::ARCH arch_;
std::vector<HalCoreInfoType> core_info_;
std::vector<DeviceAddr> dram_bases_;
Expand All @@ -159,8 +157,6 @@ class Hal {
public:
Hal();

void initialize(tt::ARCH arch);

tt::ARCH get_arch() const {return arch_;}

template <typename IndexType, typename SizeType, typename CoordType>
Expand Down Expand Up @@ -300,7 +296,24 @@ inline T Hal::get_binary_local_init_addr(uint32_t programmable_core_type_index,
return this->core_info_[programmable_core_type_index].get_binary_local_init_addr(processor_class_idx, processor_type_idx);
}

extern Hal hal;
// Wraps Hal so that exactly one instance exists per process.
// The instance is only reachable through getInstance(); it cannot be copied,
// moved, or constructed/destroyed by outside code.
class HalSingleton : public Hal {
public:
    // Returns the shared instance. It is a function-local static, so it is
    // constructed on the first call, and C++11 guarantees that this
    // initialization is thread-safe.
    static HalSingleton& getInstance() {
        static HalSingleton instance;
        return instance;
    }

    // Copying and moving are disabled to keep the instance unique.
    HalSingleton(const HalSingleton&) = delete;
    HalSingleton(HalSingleton&&) = delete;
    HalSingleton& operator=(const HalSingleton&) = delete;
    HalSingleton& operator=(HalSingleton&&) = delete;

private:
    // Construction and destruction are private; only getInstance() may create the object.
    HalSingleton() = default;
    ~HalSingleton() = default;
};

// Namespace-scope reference to the singleton, usable as `hal.<member>(...)`.
// NOTE(review): binding this reference calls HalSingleton::getInstance() during dynamic
// initialization of the inline variable, so the Hal constructor typically runs before main()
// in programs that include this header — confirm that this eager construction is intended.
inline auto& hal = HalSingleton::getInstance(); // inline variable requires C++17

} // namespace tt_metal
} // namespace tt
Expand Down
29 changes: 7 additions & 22 deletions tt_metal/llrt/tt_cluster.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@
#include "tt_metal/llrt/tlb_config.hpp"
#include "tt_metal/common/core_coord.hpp"

#include "get_platform_architecture.hpp"

static constexpr uint32_t HOST_MEM_CHANNELS = 4;
static constexpr uint32_t HOST_MEM_CHANNELS_MASK = HOST_MEM_CHANNELS - 1;

Expand All @@ -67,8 +69,6 @@ Cluster::Cluster() {

this->detect_arch_and_target();

tt_metal::hal.initialize(arch_);

this->generate_cluster_descriptor();

this->initialize_device_drivers();
Expand All @@ -83,26 +83,11 @@ Cluster::Cluster() {
}

void Cluster::detect_arch_and_target() {
if(std::getenv("TT_METAL_SIMULATOR_EN")) {
this->target_type_ = TargetDevice::Simulator;
auto arch_env = getenv("ARCH_NAME");
TT_FATAL(arch_env, "ARCH_NAME env var needed for VCS");
this->arch_ = tt::get_arch_from_string(arch_env);
}else {
this->target_type_ = TargetDevice::Silicon;
std::vector<chip_id_t> physical_mmio_device_ids = tt::umd::Cluster::detect_available_device_ids();
this->arch_ = detect_arch(physical_mmio_device_ids.at(0));
for (int dev_index = 1; dev_index < physical_mmio_device_ids.size(); dev_index++) {
chip_id_t device_id = physical_mmio_device_ids.at(dev_index);
tt::ARCH detected_arch = detect_arch(device_id);
TT_FATAL(
this->arch_ == detected_arch,
"Expected all devices to be {} but device {} is {}",
get_arch_str(this->arch_),
device_id,
get_arch_str(detected_arch));
}
}

this->target_type_ = (std::getenv("TT_METAL_SIMULATOR_EN")) ? TargetDevice::Simulator : TargetDevice::Silicon;

this->arch_ = tt_metal::get_platform_architecture();

#ifdef ARCH_GRAYSKULL
TT_FATAL(
this->arch_ == tt::ARCH::GRAYSKULL,
Expand Down

0 comments on commit a356429

Please sign in to comment.