Skip to content

Commit

Permalink
#10166: add device mesh apis to query by row and col
Browse files Browse the repository at this point in the history
  • Loading branch information
aliuTT committed Jul 18, 2024
1 parent f7b8797 commit d2aa166
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 7 deletions.
37 changes: 32 additions & 5 deletions tt_metal/impl/device/multi_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,16 @@ DeviceMesh::DeviceMesh(const DeviceGrid& device_grid, const DeviceIds &device_id
for (int i = 0; i < num_requested_devices; i++) {
mesh_devices.emplace_back(device_ids[i], managed_devices.at(galaxy_device_ids[i]));
}
this->num_rows = num_rows;
this->num_cols = num_cols;
} else {
managed_devices = tt::tt_metal::detail::CreateDevices(device_ids, num_command_queues, l1_small_size, trace_region_size);
for (int i = 0; i < num_requested_devices; i++) {
mesh_devices.emplace_back(device_ids[i], managed_devices.at(device_ids[i]));
}
// TODO: support concept of rows/cols in other systems
this->num_rows = 0;
this->num_cols = 0;
}

for (const auto& [dev_id, dev]: mesh_devices) {
Expand All @@ -71,18 +76,15 @@ DeviceMesh::~DeviceMesh() {
}
}


// Looks up a device in the mesh by its id.
// Scans the (id, device) entries of mesh_devices in order and returns the device of
// the first entry whose id equals the queried id; raises via TT_THROW when no entry matches.
// NOTE(review): this span is rendered from a diff without +/- markers, so the
// pre-change signature and comparison (queried_device_id) appear next to the
// post-change ones (logical_device_id, const-qualified).
Device* DeviceMesh::get_device(int queried_device_id)
{
Device* DeviceMesh::get_device(int logical_device_id) const {
for (const auto& [device_id, device] : mesh_devices) {
if (device_id == queried_device_id) {
if (device_id == logical_device_id) {
return device;
}
}
// Reached only when the loop found no entry with a matching id.
TT_THROW("User has provided an invalid device index");
}


std::vector<Device*> DeviceMesh::get_devices() const
{
std::vector<Device*> devices;
Expand All @@ -92,6 +94,31 @@ std::vector<Device*> DeviceMesh::get_devices() const
return devices;
}

// Returns the device stored at grid position (row_idx, col_idx).
//
// The mesh entries are addressed in row-major order: the entry for (row, col)
// lives at position row * num_cols + col of mesh_devices.
//
// Fails via TT_FATAL when:
//   - the mesh has no row/column layout (either dimension is 0), or
//   - row_idx is outside [0, num_rows), or
//   - col_idx is outside [0, num_cols).
Device* DeviceMesh::get_device(int row_idx, int col_idx) const {
    // A zero dimension marks a mesh that was built without a row/column layout.
    TT_FATAL(this->num_rows != 0 and this->num_cols != 0, "#10419, Current device mesh does not support indexing by row or col indices.");

    // Validate each coordinate separately so the failure names the bad axis.
    TT_FATAL(row_idx >= 0 and row_idx < this->num_rows, "Invalid row index.");
    TT_FATAL(col_idx >= 0 and col_idx < this->num_cols, "Invalid col index.");

    // Row-major flattening of the 2D coordinate.
    const int flat_index = col_idx + row_idx * num_cols;
    const auto& entry = mesh_devices[flat_index];
    return entry.second;
}

std::vector<Device*> DeviceMesh::get_devices_on_row(int row_idx) const {
std::vector<Device*> devices;
for (int col_idx = 0; col_idx < this->num_cols; ++col_idx) {
devices.push_back(this->get_device(row_idx, col_idx));
}
return devices;
}

std::vector<Device*> DeviceMesh::get_devices_on_column(int col_idx) const {
std::vector<Device*> devices;
for (int row_idx = 0; row_idx < this->num_rows; ++row_idx) {
devices.push_back(this->get_device(row_idx, col_idx));
}
return devices;
}

const DeviceIds DeviceMesh::get_device_ids() const
{
Expand Down
9 changes: 8 additions & 1 deletion tt_metal/impl/device/multi_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,20 @@ class DeviceMesh
DeviceMesh &operator=(DeviceMesh &&) = delete;

std::vector<Device*> get_devices() const;
Device* get_device(int queried_device_id);
Device *get_device(int logical_device_id) const;
Device *get_device(int row_idx, int col_idx) const;
std::vector<Device *> get_devices_on_row(int row_idx) const;
std::vector<Device *> get_devices_on_column(int col_idx) const;

const DeviceIds get_device_ids() const;

int num_devices() const;

void close_devices();

private:
int num_rows;
int num_cols;
};


Expand Down
29 changes: 28 additions & 1 deletion ttnn/cpp/pybind11/multi_device.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,41 @@ void py_module(py::module& module) {
py::arg("l1_small_size"),
py::arg("trace_region_size"),
py::arg("num_command_queues"))
.def("get_device", &ttnn::multi_device::DeviceMesh::get_device, py::return_value_policy::reference)
.def("get_num_devices", &ttnn::multi_device::DeviceMesh::num_devices)
.def("get_device_ids", &ttnn::multi_device::DeviceMesh::get_device_ids)
.def(
"get_device",
py::overload_cast<int>(&ttnn::multi_device::DeviceMesh::get_device, py::const_),
py::return_value_policy::reference)
.def(
"get_device",
py::overload_cast<int, int>(&ttnn::multi_device::DeviceMesh::get_device, py::const_),
py::return_value_policy::reference)
.def("get_devices", &ttnn::multi_device::DeviceMesh::get_devices, py::return_value_policy::reference, R"doc(
Get the devices in the device mesh.
Returns:
List[Device]: The devices in the device mesh.
)doc")
.def(
"get_devices_on_row",
&ttnn::multi_device::DeviceMesh::get_devices_on_row,
py::return_value_policy::reference,
R"doc(
Get the devices in a row of the device mesh.
Returns:
List[Device]: The devices on a row in the device mesh.
)doc")
.def(
"get_devices_on_column",
&ttnn::multi_device::DeviceMesh::get_devices_on_column,
py::return_value_policy::reference,
R"doc(
Get the devices in a column of the device mesh.
Returns:
List[Device]: The devices on a column in the device mesh.
)doc");

module.def(
Expand Down

0 comments on commit d2aa166

Please sign in to comment.