diff --git a/tt_metal/impl/device/multi_device.cpp b/tt_metal/impl/device/multi_device.cpp index 2676c125a6b..5b6d767de23 100644 --- a/tt_metal/impl/device/multi_device.cpp +++ b/tt_metal/impl/device/multi_device.cpp @@ -52,11 +52,16 @@ DeviceMesh::DeviceMesh(const DeviceGrid& device_grid, const DeviceIds &device_id for (int i = 0; i < num_requested_devices; i++) { mesh_devices.emplace_back(device_ids[i], managed_devices.at(galaxy_device_ids[i])); } + this->num_rows = num_rows; + this->num_cols = num_cols; } else { managed_devices = tt::tt_metal::detail::CreateDevices(device_ids, num_command_queues, l1_small_size, trace_region_size); for (int i = 0; i < num_requested_devices; i++) { mesh_devices.emplace_back(device_ids[i], managed_devices.at(device_ids[i])); } + // TODO: support concept of rows/cols in other systems + this->num_rows = 0; + this->num_cols = 0; } for (const auto& [dev_id, dev]: mesh_devices) { @@ -71,18 +76,15 @@ DeviceMesh::~DeviceMesh() { } } - -Device* DeviceMesh::get_device(int queried_device_id) -{ +Device* DeviceMesh::get_device(int logical_device_id) const { for (const auto& [device_id, device] : mesh_devices) { - if (device_id == queried_device_id) { + if (device_id == logical_device_id) { return device; } } TT_THROW("User has provided an invalid device index"); } - std::vector DeviceMesh::get_devices() const { std::vector devices; @@ -92,6 +94,31 @@ std::vector DeviceMesh::get_devices() const return devices; } +Device* DeviceMesh::get_device(int row_idx, int col_idx) const { + TT_FATAL( + this->num_rows != 0 and this->num_cols != 0, + "#10419, Current device mesh does not support indexing by row or col indices."); + TT_FATAL(row_idx >= 0 and row_idx < this->num_rows, "Invalid row index."); + TT_FATAL(col_idx >= 0 and col_idx < this->num_cols, "Invalid col index."); + int idx = row_idx * this->num_cols + col_idx; + return this->mesh_devices[idx].second; +} + +std::vector DeviceMesh::get_devices_on_row(int row_idx) const { + std::vector devices; + for (int col_idx = 0; col_idx < this->num_cols; ++col_idx) { + devices.push_back(this->get_device(row_idx, col_idx)); + } + return devices; +} + +std::vector DeviceMesh::get_devices_on_column(int col_idx) const { + std::vector devices; + for (int row_idx = 0; row_idx < this->num_rows; ++row_idx) { + devices.push_back(this->get_device(row_idx, col_idx)); + } + return devices; +} const DeviceIds DeviceMesh::get_device_ids() const { diff --git a/tt_metal/impl/device/multi_device.hpp b/tt_metal/impl/device/multi_device.hpp index b33299cc88a..cf401357f11 100644 --- a/tt_metal/impl/device/multi_device.hpp +++ b/tt_metal/impl/device/multi_device.hpp @@ -34,13 +34,20 @@ class DeviceMesh DeviceMesh &operator=(DeviceMesh &&) = delete; std::vector get_devices() const; - Device* get_device(int queried_device_id); + Device *get_device(int logical_device_id) const; + Device *get_device(int row_idx, int col_idx) const; + std::vector get_devices_on_row(int row_idx) const; + std::vector get_devices_on_column(int col_idx) const; const DeviceIds get_device_ids() const; int num_devices() const; void close_devices(); + + private: + int num_rows; + int num_cols; }; diff --git a/ttnn/cpp/pybind11/multi_device.hpp b/ttnn/cpp/pybind11/multi_device.hpp index dae19e847e4..7b920d36805 100644 --- a/ttnn/cpp/pybind11/multi_device.hpp +++ b/ttnn/cpp/pybind11/multi_device.hpp @@ -26,14 +26,41 @@ void py_module(py::module& module) { py::arg("l1_small_size"), py::arg("trace_region_size"), py::arg("num_command_queues")) - .def("get_device", &ttnn::multi_device::DeviceMesh::get_device, py::return_value_policy::reference) .def("get_num_devices", &ttnn::multi_device::DeviceMesh::num_devices) .def("get_device_ids", &ttnn::multi_device::DeviceMesh::get_device_ids) + .def( + "get_device", + py::overload_cast(&ttnn::multi_device::DeviceMesh::get_device, py::const_), + py::return_value_policy::reference) + .def( + "get_device", + py::overload_cast(&ttnn::multi_device::DeviceMesh::get_device, py::const_), + py::return_value_policy::reference) .def("get_devices", &ttnn::multi_device::DeviceMesh::get_devices, py::return_value_policy::reference, R"doc( Get the devices in the device mesh. Returns: List[Device]: The devices in the device mesh. + )doc") + .def( + "get_devices_on_row", + &ttnn::multi_device::DeviceMesh::get_devices_on_row, + py::return_value_policy::reference, + R"doc( + Get the devices in a row of the device mesh. + + Returns: + List[Device]: The devices on a row in the device mesh. + )doc") + .def( + "get_devices_on_column", + &ttnn::multi_device::DeviceMesh::get_devices_on_column, + py::return_value_policy::reference, + R"doc( + Get the devices in a row of the device mesh. + + Returns: + List[Device]: The devices on a row in the device mesh. )doc"); module.def(