diff --git a/conftest.py b/conftest.py index 8ce28b82872c..d3fc41b414af 100644 --- a/conftest.py +++ b/conftest.py @@ -235,7 +235,11 @@ def pcie_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, devic request.node.pci_ids = device_ids[:num_pcie_devices_requested] mesh_device = ttnn.open_mesh_device( - ttnn.MeshShape(2, 2), dispatch_core_type=get_dispatch_core_type(), **device_params, offset=(0, 1) + ttnn.MeshShape(2, 2), + dispatch_core_type=get_dispatch_core_type(), + **device_params, + offset=(0, 1), + mesh_type=ttnn.MeshType.Ring, ) logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created") @@ -283,6 +287,7 @@ def t3k_mesh_device(request, silicon_arch_name, silicon_arch_wormhole_b0, device ttnn.MeshShape(2, 4), dispatch_core_type=get_dispatch_core_type(), **device_params, + mesh_type=ttnn.MeshType.Ring, ) logger.debug(f"multidevice with {mesh_device.get_num_devices()} devices is created") diff --git a/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py b/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py index 0af21ac87c7a..2861253da1a9 100644 --- a/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py +++ b/models/demos/t3000/llama2_70b/tt/llama_mlp_optimized.py @@ -70,12 +70,11 @@ def load_weights(self): padded_w3 = self.state_dict[w3_str].transpose(-2, -1).view(1, 1, H, H4) # w1: 8k x 4k. width-sharded on 12 banks, 4224 over 12 banks. - device = self.mesh_device.get_device(0) weight_grid = ttnn.CoreRangeSet( { ttnn.CoreRange( ttnn.CoreCoord(0, 0), - ttnn.CoreCoord(device.dram_grid_size().x - 1, device.dram_grid_size().y - 1), + ttnn.CoreCoord(self.mesh_device.dram_grid_size().x - 1, self.mesh_device.dram_grid_size().y - 1), ) } ) diff --git a/tests/scripts/tg/run_tg_model_perf_tests.sh b/tests/scripts/tg/run_tg_model_perf_tests.sh index 7cd43da8c897..9501e79e4236 100755 --- a/tests/scripts/tg/run_tg_model_perf_tests.sh +++ b/tests/scripts/tg/run_tg_model_perf_tests.sh @@ -1,6 +1,10 @@ #!/bin/bash run_tg_llm_tests() { + + echo "LOG_METAL: Running run_t3000_llama2_70b_tests" + pytest -n auto models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py -m "model_perf_t3000" --timeout=600 ; fail+=$? + # Merge all the generated reports env python models/perf/merge_perf_results.py; fail+=$? diff --git a/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp b/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp index 6d7ed90ee8b9..2d2e99504677 100644 --- a/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp +++ b/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp @@ -130,8 +130,9 @@ TEST(GalaxyTests, TestAllGatherDeadlock) { } // Iterate over each row and run line all-gather multiple times. // For each row, send adversarial traffic to the first chip, that can hang the network if the CCL is not tagged. + auto view = MeshDeviceView(*mesh); for (uint32_t row = 0; row < 8; row++) { - auto devs = mesh->get_devices_on_row(row); + auto devs = view.get_devices_on_row(row); std::vector device_ids = {}; for (auto dev : devs) { device_ids.push_back(dev->id()); @@ -189,13 +190,14 @@ TEST(GalaxyTests, TestReduceScatterDeadlock) { std::shared_ptr mesh = ttnn::multi_device::open_mesh_device(mesh_shape, 0, 0, 1, DispatchCoreType::WORKER); // Create the outer ring on which Reduce Scatter will be run. This allows us to verify that there are no deadlocks when we send CCLs to the // first tunnel (forward path). - std::vector ring_devices = mesh->get_devices_on_row(0); // Tunnel 0 - std::vector ring_devices_1 = mesh->get_devices_on_column(mesh_shape.second - 1); // Orthogonal to tunnel .. no deadlocks + auto view = MeshDeviceView(*mesh); + std::vector ring_devices = view.get_devices_on_row(0); // Tunnel 0 + std::vector ring_devices_1 = view.get_devices_on_column(mesh_shape.second - 1); // Orthogonal to tunnel .. no deadlocks ring_devices_1 = std::vector(ring_devices_1.begin() + 1, ring_devices_1.end()); - std::vector ring_devices_2 = mesh->get_devices_on_row(7); // Tunnel 7 .. potential deadlocks with lack of buffering + std::vector ring_devices_2 = view.get_devices_on_row(7); // Tunnel 7 .. potential deadlocks with lack of buffering std::reverse(ring_devices_2.begin(), ring_devices_2.end()); ring_devices_2 = std::vector(ring_devices_2.begin() + 1, ring_devices_2.end()); - std::vector ring_devices_3 = mesh->get_devices_on_column(0); // Orthogonal to tunnel .. no deadlocks + std::vector ring_devices_3 = view.get_devices_on_column(0); // Orthogonal to tunnel .. no deadlocks std::reverse(ring_devices_3.begin(), ring_devices_3.end()); ring_devices_3 = std::vector(ring_devices_3.begin() + 1, ring_devices_3.end() - 1); diff --git a/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp b/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp index 9e779b7f0ccb..2b0a8fc04a1e 100644 --- a/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp +++ b/tests/ttnn/unit_tests/gtests/ttnn_test_fixtures.hpp @@ -69,11 +69,11 @@ class T3kMultiDeviceFixture : public ::testing::Test { } constexpr auto DEFAULT_NUM_COMMAND_QUEUES = 1; mesh_device_ = MeshDevice::create( - MeshShape{2, 4}, DEFAULT_L1_SMALL_SIZE, DEFAULT_TRACE_REGION_SIZE, DEFAULT_NUM_COMMAND_QUEUES, - DispatchCoreType::WORKER); + DispatchCoreType::WORKER, + MeshDeviceConfig(MeshShape{2, 4}, MeshType::Ring)); } void TearDown() override { diff --git a/tests/ttnn/unit_tests/test_multi_device.py b/tests/ttnn/unit_tests/test_multi_device.py index f1c2728857f0..fde9a224d719 100644 --- a/tests/ttnn/unit_tests/test_multi_device.py +++ b/tests/ttnn/unit_tests/test_multi_device.py @@ -587,3 +587,28 @@ def test_validate_as_tensor(tmp_path, mesh_device, height, width): for device in mesh_device.get_devices(): device_tensor = ttnn.get_device_tensor(tensor, device) assert torch.allclose(ttnn.to_torch(device_tensor), torch_input_tensor) + + +def test_visualize_mesh_device(t3k_mesh_device): + ttnn.visualize_mesh_device(t3k_mesh_device) + + +def test_all_gather_multiple_submeshes(t3k_mesh_device): + """Test all_gather with multiple submeshes""" + + def model(submesh): + full_tensor = torch.ones((1, 1, 32, 32 * submesh.get_num_devices()), dtype=torch.bfloat16) + for i in range(submesh.get_num_devices()): + full_tensor[..., i * 32 : (i + 1) * 32] = i + + ttnn_tensor = ttnn.from_torch(full_tensor, mesh_mapper=ShardTensorToMesh(submesh, dim=3)) + ttnn_tensor = ttnn.to_device(ttnn_tensor, submesh) + ttnn_tensor = ttnn.all_gather(ttnn_tensor, dim=3, num_links=1) + + for device_tensor in ttnn.get_device_tensors(ttnn_tensor): + device_tensor_torch = ttnn.to_torch(device_tensor) + assert torch.all(device_tensor_torch == full_tensor) + + submesh_devices = t3k_mesh_device.create_submeshes((2, 2), ttnn.MeshType.Ring) + for submesh in submesh_devices: + model(submesh) diff --git a/tt_metal/impl/device/mesh_configurations/T3000.json b/tt_metal/impl/device/mesh_configurations/T3000.json index 2c62209d01fc..acfe3edac004 100644 --- a/tt_metal/impl/device/mesh_configurations/T3000.json +++ b/tt_metal/impl/device/mesh_configurations/T3000.json @@ -1,6 +1,6 @@ { "logical_to_physical_coordinates": [ [[0, 0], [0, 0, 0, 0]], [[0, 1], [0, 1, 0, 0]], [[0, 2], [0, 2, 0, 0]], [[0, 3], [0, 3, 0, 0]], - [[1, 0], [1, 3, 0, 0]], [[1, 1], [1, 2, 0, 0]], [[1, 2], [1, 1, 0, 0]], [[1, 3], [1, 0, 0, 0]] + [[1, 0], [1, 0, 0, 0]], [[1, 1], [1, 1, 0, 0]], [[1, 2], [1, 2, 0, 0]], [[1, 3], [1, 3, 0, 0]] ] } diff --git a/tt_metal/impl/device/mesh_device.cpp b/tt_metal/impl/device/mesh_device.cpp index e90d4a8925e2..0c8d169056c3 100644 --- a/tt_metal/impl/device/mesh_device.cpp +++ b/tt_metal/impl/device/mesh_device.cpp @@ -23,7 +23,7 @@ static std::string get_config_path(const std::string& filename) { return root_path + "/tt_metal/impl/device/mesh_configurations/" + filename; } -static std::map load_translation_map(const std::string& filename, const std::string& key) { +static std::unordered_map load_translation_map(const std::string& filename, const std::string& key) { std::ifstream file(filename); if (!file.is_open()) { throw std::runtime_error("Unable to open file: " + filename); @@ -40,7 +40,7 @@ static std::map load_translation_map(cons throw std::runtime_error("Key '" + key + "' not found in JSON file: " + filename); } - std::map result; + std::unordered_map result; for (const auto& mapping : j[key]) { if (mapping.size() != 2 || mapping[0].size() != 2 || mapping[1].size() != 4) { throw std::runtime_error("Invalid coordinate format in JSON file: " + filename); @@ -51,8 +51,8 @@ static std::map load_translation_map(cons return result; } -MeshShape SystemMesh::get_system_mesh_shape(std::size_t system_num_devices) { - const std::unordered_map system_mesh_to_shape = { +MeshShape SystemMesh::get_system_mesh_shape(size_t system_num_devices) { + const std::unordered_map system_mesh_to_shape = { {1, MeshShape{1, 1}}, // single-device {2, MeshShape{1, 2}}, // N300 {8, MeshShape{2, 4}}, // T3000; as ring to match existing tests @@ -65,8 +65,8 @@ MeshShape SystemMesh::get_system_mesh_shape(std::size_t system_num_devices) { return shape; } -std::map SystemMesh::get_system_mesh_translation_map(std::size_t system_num_devices) { - const std::unordered_map system_mesh_translation_map = { +std::unordered_map SystemMesh::get_system_mesh_translation_map(size_t system_num_devices) { + const std::unordered_map system_mesh_translation_map = { {1, "device.json"}, {2, "N300.json"}, {8, "T3000.json"}, @@ -102,7 +102,7 @@ void SystemMesh::initialize() { } const MeshShape& SystemMesh::get_shape() const { return this->logical_mesh_shape; } -std::size_t SystemMesh::get_num_devices() const { +size_t SystemMesh::get_num_devices() const { auto [num_rows, num_cols] = this->get_shape(); return num_rows * num_cols; } @@ -113,20 +113,41 @@ std::vector SystemMesh::get_mapped_physical_device_ids(const MeshDevi auto [requested_rows, requested_cols] = config.mesh_shape; auto [row_offset, col_offset] = config.offset; - for (int row = 0; row < requested_rows; row++) { - for (int col = 0; col < requested_cols; col++) { - auto logical_device_id = (row + row_offset) * system_mesh_cols + (col + col_offset); - auto logical_coordinate = Coordinate{logical_device_id / system_mesh_cols, logical_device_id % system_mesh_cols}; + if (requested_rows == 1) { + TT_FATAL(row_offset == 0 and col_offset == 0, "Row and column offsets unsupported for single row mesh"); + auto line_coords = MeshDeviceView::get_line_coordinates(requested_cols, Coordinate{row_offset, col_offset}, system_mesh_rows, system_mesh_cols); + for (const auto& logical_coordinate : line_coords) { auto physical_coordinate = this->logical_to_physical_coordinates.at(logical_coordinate); auto physical_device_id = this->physical_coordinate_to_device_id.at(physical_coordinate); physical_device_ids.push_back(physical_device_id); - log_debug(LogMetal, "Logical device ID: {}, Logical coordinate: {}, Physical coordinate: {}, Physical device ID: {}", - logical_device_id, logical_coordinate, physical_coordinate, physical_device_id); + log_debug(LogMetal, "Logical coordinate: {}, Physical coordinate: {}, Physical device ID: {}", + logical_coordinate, physical_coordinate, physical_device_id); + } + } else { + for (int row = 0; row < requested_rows; row++) { + for (int col = 0; col < requested_cols; col++) { + auto logical_device_id = (row + row_offset) * system_mesh_cols + (col + col_offset); + auto logical_coordinate = Coordinate{logical_device_id / system_mesh_cols, logical_device_id % system_mesh_cols}; + auto physical_coordinate = this->logical_to_physical_coordinates.at(logical_coordinate); + auto physical_device_id = this->physical_coordinate_to_device_id.at(physical_coordinate); + physical_device_ids.push_back(physical_device_id); + + log_debug(LogMetal, "Logical device ID: {}, Logical coordinate: {}, Physical coordinate: {}, Physical device ID: {}", + logical_device_id, logical_coordinate, physical_coordinate, physical_device_id); + } } } return physical_device_ids; } +void SystemMesh::register_mesh_device(const std::shared_ptr &mesh_device, const std::vector& devices) { + std::vector physical_device_ids; + for (auto device : devices) { + physical_device_ids.push_back(device->id()); + } + this->assigned_mesh_device_devices.insert({mesh_device->get_mesh_id(), mesh_device}); + this->assigned_devices.insert({mesh_device->get_mesh_id(), physical_device_ids}); +} std::vector SystemMesh::map_mesh_device( std::shared_ptr mesh_device, @@ -134,22 +155,20 @@ std::vector SystemMesh::map_mesh_device( size_t l1_small_size, size_t trace_region_size, DispatchCoreType dispatch_core_type, - const std::pair& offset, - const std::vector& user_provided_physical_device_ids) { + const MeshDeviceConfig& config) { auto [requested_num_rows, requested_num_cols] = mesh_device->shape(); auto [max_num_rows, max_num_cols] = this->logical_mesh_shape; - auto [row_offset, col_offset] = offset; + auto [row_offset, col_offset] = config.offset; log_debug(LogMetal, "Mapping MeshDevice ({}x{}) with offset: {}, {}", requested_num_rows, requested_num_cols, row_offset, col_offset); TT_FATAL(requested_num_rows <= max_num_rows, "Requested too many rows: {} > {}", requested_num_rows, max_num_rows); TT_FATAL(requested_num_rows*requested_num_cols <= max_num_rows*max_num_cols, "Requested submesh is too big: {}x{}", requested_num_rows, requested_num_cols); - this->assigned_mesh_device_devices.insert({mesh_device->get_mesh_id(), mesh_device}); - auto physical_device_ids = user_provided_physical_device_ids.empty() ? - this->get_mapped_physical_device_ids(MeshDeviceConfig{mesh_device->shape(), offset}) : - user_provided_physical_device_ids; + auto physical_device_ids = config.physical_device_ids.empty() ? + this->get_mapped_physical_device_ids(config) : + config.physical_device_ids; this->opened_devices[mesh_device->get_mesh_id()] = tt::tt_metal::detail::CreateDevices( physical_device_ids, num_command_queues, l1_small_size, trace_region_size, dispatch_core_type); @@ -158,27 +177,34 @@ std::vector SystemMesh::map_mesh_device( for (auto physical_device_id : physical_device_ids) { auto mapped_device = this->opened_devices[mesh_device->get_mesh_id()].at(physical_device_id); mapped_devices.push_back(mapped_device); - this->assigned_devices[mesh_device->get_mesh_id()].push_back(physical_device_id); this->assigned_physical_id_to_device.insert({physical_device_id, mapped_device}); } + + this->register_mesh_device(mesh_device, mapped_devices); // here return mapped_devices; } -void SystemMesh::unmap_mesh_device(const std::shared_ptr& mesh_device) { +void SystemMesh::unmap_mesh_device(const MeshDevice* mesh_device) { auto mesh_id = mesh_device->get_mesh_id(); - - // Clean up all state related to this virtual mesh this->assigned_mesh_device_devices.erase(mesh_id); - // Remove the devices from assigned_physical_id_to_device - for (auto physical_id : this->assigned_devices.at(mesh_id)) { - this->assigned_physical_id_to_device.erase(physical_id); + // Close the devices + if (mesh_device->is_parent_mesh()) { + for (auto physical_id : this->assigned_devices.at(mesh_id)) { + this->assigned_physical_id_to_device.erase(physical_id); + } + tt::tt_metal::detail::CloseDevices(this->opened_devices.at(mesh_id)); + this->opened_devices.erase(mesh_id); } this->assigned_devices.erase(mesh_id); +} - // Close the devices - tt::tt_metal::detail::CloseDevices(this->opened_devices.at(mesh_id)); - this->opened_devices.erase(mesh_id); +Device* SystemMesh::get_device(const chip_id_t physical_device_id) const { + auto it = this->assigned_physical_id_to_device.find(physical_device_id); + if (it == this->assigned_physical_id_to_device.end()) { + TT_THROW("Physical Device ID: {} not found in assigned devices", physical_device_id); + } + return it->second; } static MeshDeviceID generate_unique_mesh_id() { @@ -186,30 +212,76 @@ static MeshDeviceID generate_unique_mesh_id() { return next_id++; } -MeshDevice::MeshDevice(const MeshShape& mesh_device_shape) : mesh_device_shape(mesh_device_shape), mesh_id(generate_unique_mesh_id()) {} +MeshDevice::MeshDevice(const MeshShape& mesh_device_shape, MeshType type, std::weak_ptr parent_mesh) + : mesh_device_shape(mesh_device_shape), type(type), mesh_id(generate_unique_mesh_id()), parent_mesh(parent_mesh) {} std::shared_ptr MeshDevice::create( - const MeshShape& mesh_device_shape, size_t l1_small_size, size_t trace_region_size, size_t num_command_queues, DispatchCoreType dispatch_core_type, - const std::pair& offset, - const std::vector& user_provided_physical_device_ids) + const MeshDeviceConfig& config) { - auto mesh_device = std::make_shared(mesh_device_shape); - mesh_device->initialize(l1_small_size, trace_region_size, num_command_queues, dispatch_core_type, offset, user_provided_physical_device_ids); + auto mesh_device = std::make_shared(config.mesh_shape, config.mesh_type); + mesh_device->initialize(l1_small_size, trace_region_size, num_command_queues, dispatch_core_type, config); return mesh_device; } +std::shared_ptr MeshDevice::create_submesh( + const MeshShape &submesh_shape, + const MeshOffset &offset, + MeshType type) +{ + if (submesh_shape.first <= 0 || submesh_shape.second <= 0) { + TT_THROW("Invalid submesh shape: ({}, {}). Both dimensions must be positive.", submesh_shape.first, submesh_shape.second); + } + + if (offset.first < 0 || offset.second < 0) { + TT_THROW("Invalid offset: ({}, {}). Offset must be non-negative.", offset.first, offset.second); + } + + if (offset.first + submesh_shape.first > this->mesh_device_shape.first || + offset.second + submesh_shape.second > this->mesh_device_shape.second) { + TT_THROW("Submesh ({}x{}) with offset ({}, {}) does not fit within parent mesh ({}x{}).", + submesh_shape.first, submesh_shape.second, + offset.first, offset.second, + this->mesh_device_shape.first, this->mesh_device_shape.second); + } + + auto submesh = std::make_shared(submesh_shape, type, shared_from_this()); + auto start_coordinate = Coordinate{offset.first, offset.second}; + auto end_coordinate = Coordinate{offset.first + submesh_shape.first - 1, offset.second + submesh_shape.second - 1}; + submesh->primary_view = std::make_shared(*this, start_coordinate, end_coordinate); + submesh->devices = submesh->primary_view->get_devices(); + SystemMesh::instance().register_mesh_device(submesh, submesh->devices); + this->submeshes.push_back(submesh); + log_trace(LogMetal, "Instantiating submesh {}: {}x{} with offset: {} {}", submesh->get_mesh_id(), submesh_shape.first, submesh_shape.second, offset.first, offset.second); + log_trace(LogMetal, "Submesh {} instantiated with {} devices", submesh->get_mesh_id(), submesh->devices); + + return submesh; +} + +std::vector> MeshDevice::create_submeshes( + const MeshShape &submesh_shape, + MeshType type) +{ + std::vector> submeshes; + for (int row = 0; row < this->num_rows(); row += submesh_shape.first) { + for (int col = 0; col < this->num_cols(); col += submesh_shape.second) { + auto submesh = this->create_submesh(submesh_shape, MeshOffset{row, col}, type); + submeshes.push_back(submesh); + } + } + return submeshes; +} + void MeshDevice::initialize( size_t l1_small_size, size_t trace_region_size, size_t num_command_queues, DispatchCoreType dispatch_core_type, - const std::pair& offset, - const std::vector& physical_device_ids) + const MeshDeviceConfig& config) { auto [num_rows, num_cols] = this->shape(); auto num_requested_devices = num_rows * num_cols; @@ -221,43 +293,29 @@ void MeshDevice::initialize( auto& instance = SystemMesh::instance(); this->devices = instance.map_mesh_device( - shared_from_this(), num_command_queues, l1_small_size, trace_region_size, dispatch_core_type, offset, physical_device_ids); - this->primary_view = std::make_unique(*this); - - for (int device_index = 0; device_index < this->devices.size(); device_index++) { - this->physical_id_to_device_index.insert({this->devices[device_index]->id(), device_index}); - } + shared_from_this(), num_command_queues, l1_small_size, trace_region_size, dispatch_core_type, config); + this->primary_view = std::make_shared(*this); } MeshDevice::~MeshDevice() { - if (not this->devices.empty()) { - this->close_devices(); - } + close_devices(); } -Device* MeshDevice::get_device_index(int logical_device_id) const { +Device* MeshDevice::get_device_index(size_t logical_device_id) const { TT_FATAL(logical_device_id >= 0 and logical_device_id < num_devices(), "Invalid device index"); return this->devices.at(logical_device_id); } -Device* MeshDevice::get_device(int physical_device_id) const { - return this->devices.at(this->physical_id_to_device_index.at(physical_device_id)); +Device* MeshDevice::get_device(chip_id_t physical_device_id) const { + return SystemMesh::instance().get_device(physical_device_id); } -std::vector MeshDevice::get_devices() const { return this->devices; } +std::vector MeshDevice::get_devices() const { return this->primary_view->get_devices(this->type); } -Device* MeshDevice::get_device(int row_idx, int col_idx) const { +Device* MeshDevice::get_device(size_t row_idx, size_t col_idx) const { return this->get_device_index(row_idx * num_cols() + col_idx); } -std::vector MeshDevice::get_devices_on_row(int row_idx) const { - return this->primary_view->get_devices_on_row(row_idx); -} - -std::vector MeshDevice::get_devices_on_column(int col_idx) const { - return this->primary_view->get_devices_on_column(col_idx); -} - const DeviceIds MeshDevice::get_device_ids() const { DeviceIds device_ids; for (auto device : this->get_devices()) { @@ -266,7 +324,7 @@ const DeviceIds MeshDevice::get_device_ids() const { return device_ids; } -int MeshDevice::num_devices() const { return num_rows() * num_cols(); } +size_t MeshDevice::num_devices() const { return this->devices.size(); } CoreCoord MeshDevice::compute_with_storage_grid_size() const { return get_device_index(0)->compute_with_storage_grid_size(); } @@ -274,16 +332,21 @@ CoreCoord MeshDevice::dram_grid_size() const { return get_device_index(0)->dram_ tt::ARCH MeshDevice::arch() const { return get_device_index(0)->arch(); } -int MeshDevice::num_rows() const { return this->mesh_device_shape.first; } +size_t MeshDevice::num_rows() const { return this->mesh_device_shape.first; } -int MeshDevice::num_cols() const { return this->mesh_device_shape.second; } +size_t MeshDevice::num_cols() const { return this->mesh_device_shape.second; } MeshShape MeshDevice::shape() const { return this->mesh_device_shape; } void MeshDevice::close_devices() { - SystemMesh::instance().unmap_mesh_device(shared_from_this()); + for (auto submesh : this->submeshes) { + submesh->close_devices(); + } + if (not this->devices.empty()) { + SystemMesh::instance().unmap_mesh_device(this); + } + this->parent_mesh.reset(); this->devices.clear(); - this->physical_id_to_device_index.clear(); this->primary_view.reset(); } @@ -297,6 +360,36 @@ std::shared_ptr MeshDevice::get_view() { return this->primary_vi MeshDeviceID MeshDevice::get_mesh_id() const { return this->mesh_id; } +bool MeshDevice::is_parent_mesh() const { return this->parent_mesh.expired(); } + +std::shared_ptr SystemMesh::get_mesh_device(const std::vector& physical_device_ids) { + log_trace(LogMetal, "Getting mesh device for {} physical devices: {}", physical_device_ids.size(), physical_device_ids); + std::unordered_set input_set(physical_device_ids.begin(), physical_device_ids.end()); + + for (const auto& [mesh_id, mesh_device] : this->assigned_mesh_device_devices) { + const auto& assigned_devices = this->assigned_devices.at(mesh_id); + std::unordered_set assigned_set(assigned_devices.begin(), assigned_devices.end()); + log_trace(LogMetal, "Assigned devices: {}", assigned_devices); + + if (input_set == assigned_set) { + return mesh_device; + } + } + TT_THROW("No mesh device found for the provided devices"); +} + +std::shared_ptr MeshDevice::fetch_mesh_device(const std::vector& devices) { + TT_FATAL(devices.size() > 0, "No devices provided"); + auto& instance = SystemMesh::instance(); + std::vector physical_device_ids; + for (auto device : devices) { + physical_device_ids.push_back(device->id()); + } + return instance.get_mesh_device(physical_device_ids); +} + +std::vector> MeshDevice::get_submeshes() const { return this->submeshes; } + std::ostream& operator<<(std::ostream& os, const MeshDevice& mesh_device) { return os << mesh_device.to_string(); } bool validate_worker_modes(const std::vector& workers) { @@ -313,7 +406,8 @@ std::vector get_t3k_physical_device_ids_ring() { auto num_devices = instance.get_num_devices(); TT_FATAL(num_devices == 8, "T3000 ring topology only works with 8 devices"); - auto physical_device_ids = instance.get_mapped_physical_device_ids(MeshDeviceConfig{instance.get_shape(), MeshOffset{0, 0}}); + auto physical_device_ids = instance.get_mapped_physical_device_ids( + MeshDeviceConfig(MeshShape{1, 8}, MeshOffset{0, 0})); return physical_device_ids; } diff --git a/tt_metal/impl/device/mesh_device.hpp b/tt_metal/impl/device/mesh_device.hpp index 940110973cce..1f3cf43592fb 100644 --- a/tt_metal/impl/device/mesh_device.hpp +++ b/tt_metal/impl/device/mesh_device.hpp @@ -9,20 +9,39 @@ #include #include -#include "mesh_device_view.hpp" #include "tt_metal/impl/device/device.hpp" #include "tt_metal/impl/device/mesh_device_view.hpp" namespace tt::tt_metal { using DeviceIds = std::vector; -using MeshDeviceID = std::size_t; +using MeshDeviceID = size_t; using MeshOffset = std::pair; class MeshDeviceView; struct MeshDeviceConfig { MeshShape mesh_shape; MeshOffset offset; + std::vector physical_device_ids; + MeshType mesh_type; + + MeshDeviceConfig( + const MeshShape &mesh_shape, + MeshType mesh_type = MeshType::RowMajor) : + mesh_shape(mesh_shape), + offset(MeshOffset{0, 0}), + physical_device_ids(std::vector()), + mesh_type(mesh_type) {} + + MeshDeviceConfig( + const MeshShape &mesh_shape, + const MeshOffset &offset = MeshOffset{0, 0}, + const std::vector &physical_device_ids = {}, + MeshType mesh_type = MeshType::RowMajor) : + mesh_shape(mesh_shape), + offset(offset), + physical_device_ids(physical_device_ids), + mesh_type(mesh_type) {} }; // SystemMesh creates a virtualization over the physical devices in the system. @@ -43,7 +62,7 @@ class SystemMesh { // Logical mesh shape and coordinates MeshShape logical_mesh_shape; - std::map logical_to_physical_coordinates; + std::unordered_map logical_to_physical_coordinates; // Handling of physical coordinates std::unordered_map physical_coordinate_to_device_id; @@ -55,9 +74,9 @@ class SystemMesh { SystemMesh(SystemMesh &&) = delete; SystemMesh &operator=(SystemMesh &&) = delete; - static MeshShape get_system_mesh_shape(std::size_t system_num_devices); - static std::map get_system_mesh_translation_map( - std::size_t system_num_devices); + static MeshShape get_system_mesh_shape(size_t system_num_devices); + static std::unordered_map get_system_mesh_translation_map( + size_t system_num_devices); bool is_system_mesh_initialized() const; @@ -68,10 +87,11 @@ class SystemMesh { // Return the shape of the logical mesh const MeshShape &get_shape() const; - std::size_t get_num_devices() const; + size_t get_num_devices() const; // Get the physical device IDs mapped to a MeshDevice std::vector get_mapped_physical_device_ids(const MeshDeviceConfig &config) const; + void register_mesh_device(const std::shared_ptr &mesh_device, const std::vector& devices); // Map MeshDevice to physical devices std::vector map_mesh_device( @@ -80,30 +100,33 @@ class SystemMesh { size_t l1_small_size, size_t trace_region_size, DispatchCoreType dispatch_core_type, - const std::pair &offset = {0, 0}, - const std::vector &physical_device_ids = {}); + const MeshDeviceConfig &config); // Unmap MeshDevice, releasing the associated physical devices. - void unmap_mesh_device(const std::shared_ptr &mesh_device); + void unmap_mesh_device(const MeshDevice* mesh_device); + std::shared_ptr get_mesh_device(const std::vector& physical_device_ids); + Device* get_device(const chip_id_t physical_device_id) const; }; class MeshDevice : public std::enable_shared_from_this { + private: MeshDeviceID mesh_id; MeshShape mesh_device_shape; + MeshType type; std::shared_ptr primary_view; std::vector devices; - std::unordered_map physical_id_to_device_index; + std::vector> submeshes; // Parent owns submeshes and responsible fortheir destruction + std::weak_ptr parent_mesh; // Submesh created with reference to parent mesh void initialize( size_t l1_small_size, size_t trace_region_size, size_t num_command_queues, DispatchCoreType dispatch_core_type, - const std::pair &offset, - const std::vector &physical_device_ids); + const MeshDeviceConfig &config); public: - MeshDevice(const MeshShape &mesh_device_shape); + MeshDevice(const MeshShape &mesh_device_shape, MeshType type, std::weak_ptr parent_mesh = {}); ~MeshDevice(); MeshDevice(const MeshDevice &) = delete; @@ -113,17 +136,15 @@ class MeshDevice : public std::enable_shared_from_this { MeshDevice &operator=(MeshDevice &&) = delete; std::vector get_devices() const; - Device *get_device_index(int logical_device_id) const; - Device *get_device(int physical_device_id) const; - Device *get_device(int row_idx, int col_idx) const; - std::vector get_devices_on_row(int row_idx) const; - std::vector get_devices_on_column(int col_idx) const; + Device *get_device_index(size_t logical_device_id) const; + Device *get_device(chip_id_t physical_device_id) const; + Device *get_device(size_t row_idx, size_t col_idx) const; const DeviceIds get_device_ids() const; - int num_devices() const; - int num_rows() const; - int num_cols() const; + size_t num_devices() const; + size_t num_rows() const; + size_t num_cols() const; MeshShape shape() const; CoreCoord compute_with_storage_grid_size() const; @@ -138,15 +159,26 @@ class MeshDevice : public std::enable_shared_from_this { std::string to_string() const; MeshDeviceID get_mesh_id() const; + bool is_parent_mesh() const; + std::vector> get_submeshes() const; + + std::shared_ptr create_submesh( + const MeshShape &submesh_shape, + const MeshOffset &offset = MeshOffset{0, 0}, + MeshType type = MeshType::RowMajor); + + std::vector> create_submeshes( + const MeshShape &submesh_shape, + MeshType type = MeshType::RowMajor); + + static std::shared_ptr fetch_mesh_device(const std::vector& devices); static std::shared_ptr create( - const MeshShape &mesh_device_shape, size_t l1_small_size, size_t trace_region_size, size_t num_command_queues, DispatchCoreType dispatch_core_type, - const std::pair &offset = {0, 0}, - const std::vector &physical_device_ids = {}); + const MeshDeviceConfig &config); }; std::ostream &operator<<(std::ostream &os, const MeshDevice &mesh_device); diff --git a/tt_metal/impl/device/mesh_device_view.cpp b/tt_metal/impl/device/mesh_device_view.cpp index cc4a227780f6..48d8e151549c 100644 --- a/tt_metal/impl/device/mesh_device_view.cpp +++ b/tt_metal/impl/device/mesh_device_view.cpp @@ -3,14 +3,26 @@ // SPDX-License-Identifier: Apache-2.0 #include "tt_metal/impl/device/mesh_device_view.hpp" -#include "tt_metal/impl/device/mesh_device.hpp" + #include #include +#include "tt_metal/impl/device/mesh_device.hpp" + namespace tt::tt_metal { using MeshDevice = tt::tt_metal::MeshDevice; +static std::vector get_devices_from_coordinates(MeshDeviceView& mesh, const std::vector& coords) { + std::vector devices; + for (const auto& coord : coords) { + if (auto device = mesh.get_device(coord.row, coord.col)) { + devices.push_back(device); + } + } + return devices; +} + MeshDeviceView::MeshDeviceView(const MeshDevice& mesh) : top_left_(0, 0), bottom_right_(mesh.num_rows() - 1, mesh.num_cols() - 1) { for (size_t row = 0; row < mesh.num_rows(); ++row) { @@ -24,12 +36,12 @@ MeshDeviceView::MeshDeviceView(const MeshDevice& mesh) } MeshDeviceView::MeshDeviceView(const MeshDevice& mesh, Coordinate top_left, Coordinate bottom_right) - : top_left_(top_left), bottom_right_(bottom_right) { + : top_left_(0, 0), bottom_right_(Coordinate{bottom_right.row - top_left.row, bottom_right.col - top_left.col}) { for (size_t row = top_left.row; row <= bottom_right.row; ++row) { for (size_t col = top_left.col; col <= bottom_right.col; ++col) { if (auto device = mesh.get_device(row, col)) { devices_.push_back(device); - device_coordinates_[(device)->id()] = {row, col}; + device_coordinates_[(device)->id()] = {row - top_left.row, col - top_left.col}; } } } @@ -55,10 +67,6 @@ MeshDeviceView::const_device_pointer MeshDeviceView::get_device(size_t row, size return nullptr; } -const std::vector& MeshDeviceView::get_devices() const { - return devices_; -} - MeshDeviceView::DeviceView MeshDeviceView::get_devices(const Coordinate& start, const Coordinate& end) { if (start.row > end.row || start.col > end.col) { log_fatal("Invalid coordinates: start {} must be less than or equal to end {}", start, end); @@ -117,16 +125,6 @@ std::vector> MeshDeviceView::get_col return column_views; } -template -MeshDeviceView MeshDeviceView::subview(Pred&& predicate) const { - std::vector filtered_devices; - std::copy_if(devices_.begin(), devices_.end(), std::back_inserter(filtered_devices), std::forward(predicate)); - return MeshDeviceView(filtered_devices, [this](int device_id) { - auto it = device_coordinates_.find(device_id); - return it != device_coordinates_.end() ? std::optional(it->second) : std::nullopt; - }); -} - bool MeshDeviceView::empty() const noexcept { return devices_.empty(); } @@ -158,6 +156,10 @@ bool MeshDeviceView::operator==(const MeshDeviceView& other) const { bottom_right_ == other.bottom_right_; } +bool MeshDeviceView::contains_device(chip_id_t device_id) const { + return device_coordinates_.find(device_id) != device_coordinates_.end(); +} + Coordinate MeshDeviceView::find_device(chip_id_t device_id) const { auto it = device_coordinates_.find(device_id); if (it != device_coordinates_.end()) { @@ -174,8 +176,8 @@ chip_id_t MeshDeviceView::find_device_id(const Coordinate& coord) const { } void MeshDeviceView::initialize_from_devices(const std::vector& devices, CoordinateMapper mapper) { - std::size_t min_row = std::numeric_limits::max(), min_col = std::numeric_limits::max(); - std::size_t max_row = std::numeric_limits::min(), max_col = std::numeric_limits::min(); + size_t min_row = std::numeric_limits::max(), min_col = std::numeric_limits::max(); + size_t max_row = std::numeric_limits::min(), max_col = std::numeric_limits::min(); for (const auto& device : devices) { auto coord = mapper(device->id()); @@ -194,10 +196,95 @@ void MeshDeviceView::initialize_from_devices(const std::vector& bottom_right_ = {max_row, max_col}; } +std::vector MeshDeviceView::get_line_coordinates( + size_t length, const Coordinate& offset, size_t num_rows, size_t num_cols) { + std::vector line_coords; + auto [row_index, col_index] = offset; + bool left_to_right = true; + + for (size_t i = 0; i < length && row_index < num_rows && col_index < num_cols; ++i) { + line_coords.emplace_back(Coordinate{row_index, col_index}); + + if (left_to_right && col_index < num_cols - 1) { + col_index++; + } else if (!left_to_right && col_index > 0) { + col_index--; + } else { + row_index++; + left_to_right = !left_to_right; + } + } + + TT_FATAL(line_coords.size() == length, "Failed to get line coordinates"); + return line_coords; +} + +std::vector MeshDeviceView::get_ring_coordinates(const MeshShape& ring_shape, const Coordinate& offset, size_t num_rows, size_t num_cols) { + auto [start_row, start_col] = offset; + auto [ring_rows, ring_cols] = ring_shape; + auto end_row = start_row + ring_rows - 1; + auto end_col = start_col + ring_cols - 1; + + // Validate the specified subgrid + std::vector boundary_coords; + if (start_row + ring_rows > num_rows || start_col + ring_cols > num_cols) { + throw std::invalid_argument("Subgrid is out of mesh bounds."); + } + + // Traverse the top row from left to right + for (size_t col = start_col; col <= end_col; ++col) { + boundary_coords.emplace_back(Coordinate{start_row, col}); + } + + // Traverse the rightmost column from top+1 to bottom + for (size_t row = start_row + 1; row <= end_row; ++row) { + boundary_coords.emplace_back(Coordinate{row, end_col}); + } + + // Traverse the bottom row from right to left, if there is more than one row + if (ring_rows > 1 and ring_cols > 1) { + // Traverse the bottom row from right to left + for (int col = static_cast(end_col - 1); col >= static_cast(start_col); --col) { + boundary_coords.emplace_back(Coordinate{end_row, static_cast(col)}); + } + + // Traverse the leftmost column from bottom-1 to top+1 + for (int row = static_cast(end_row - 1); row > static_cast(start_row); --row) { + boundary_coords.emplace_back(Coordinate{static_cast(row), start_col}); + } + } + + return boundary_coords; +} + + void MeshDeviceView::validate_coordinates() const { if (top_left_.row > bottom_right_.row || top_left_.col > bottom_right_.col) { throw std::invalid_argument("Invalid coordinates: top_left must be less than or equal to bottom_right"); } } -} // namespace tt::tt_metal +std::vector MeshDeviceView::get_line_devices() { + auto boundary_coords = get_line_coordinates(this->num_rows() * this->num_cols(), this->top_left_, this->num_rows(), this->num_cols()); + return get_devices_from_coordinates(*this, boundary_coords); +} + +std::vector MeshDeviceView::get_ring_devices() { + auto boundary_coords = get_ring_coordinates(shape(), this->top_left_, this->num_rows(), this->num_cols()); + return get_devices_from_coordinates(*this, boundary_coords); +} + +MeshDeviceView::DeviceView MeshDeviceView::get_devices(MeshType type) { + switch (type) { + case MeshType::RowMajor: + return this->devices_; + case MeshType::Ring: + return this->get_ring_devices(); + case MeshType::Line: + return this->get_line_devices(); + default: + TT_THROW("Unsupported Mesh type: {}", type); + } +} + +} // namespace tt::tt_metal diff --git a/tt_metal/impl/device/mesh_device_view.hpp b/tt_metal/impl/device/mesh_device_view.hpp index 73c9e2b61c20..2b16a2652b95 100644 --- a/tt_metal/impl/device/mesh_device_view.hpp +++ b/tt_metal/impl/device/mesh_device_view.hpp @@ -20,12 +20,12 @@ class MeshDevice; using MeshShape = std::pair; struct Coordinate { - std::size_t row; - std::size_t col; + size_t row; + size_t col; auto operator<=>(const Coordinate&) const = default; // Add support for structured bindings - template + template decltype(auto) get() const { if constexpr (I == 0) return row; else if constexpr (I == 1) return col; @@ -53,6 +53,13 @@ struct Coordinate { * specific sub-regions. This is particularly useful for collective communication operations * (CCL-ops), such as line all-gather, which require column or row views of the device mesh. */ + +enum class MeshType { + RowMajor, + Ring, + Line +}; + class MeshDeviceView { public: using device_pointer = Device*; @@ -68,12 +75,11 @@ class MeshDeviceView { [[nodiscard]] device_pointer get_device(size_t row, size_t col); [[nodiscard]] const_device_pointer get_device(size_t row, size_t col) const; - [[nodiscard]] const std::vector& get_devices() const; - // Get devices spanning the rectangular region defined by the top-left and bottom-right coordinates // devices are returned in row-major order with start/end coordinates inclusive [[nodiscard]] DeviceView get_devices(const Coordinate& start, const Coordinate& end); [[nodiscard]] DeviceView get_devices(const MeshShape& shape); + [[nodiscard]] DeviceView get_devices(MeshType type = MeshType::RowMajor); [[nodiscard]] DeviceView get_devices_on_row(size_t row) const; [[nodiscard]] DeviceView get_devices_on_column(size_t col) const; @@ -81,12 +87,9 @@ class MeshDeviceView { [[nodiscard]] DeviceViews get_row_views() const; [[nodiscard]] DeviceViews get_column_views() const; - template - [[nodiscard]] MeshDeviceView subview(Pred&& predicate) const; - [[nodiscard]] bool empty() const noexcept; [[nodiscard]] size_t size() const noexcept; - [[nodiscard]] std::pair shape() const noexcept; + [[nodiscard]] MeshShape shape() const noexcept; [[nodiscard]] bool contains(const Coordinate& coord) const noexcept; [[nodiscard]] const_device_pointer at(const Coordinate& coord) const noexcept; @@ -95,13 +98,21 @@ class MeshDeviceView { auto begin() const { return devices_.begin(); } auto end() const { return devices_.end(); } - [[nodiscard]] std::size_t num_rows() const { return bottom_right_.row - top_left_.row + 1; } - [[nodiscard]] std::size_t num_cols() const { return bottom_right_.col - top_left_.col + 1; } - [[nodiscard]] std::size_t num_devices() const { return devices_.size(); } + [[nodiscard]] size_t num_rows() const { return bottom_right_.row - top_left_.row + 1; } + [[nodiscard]] size_t num_cols() const { return bottom_right_.col - top_left_.col + 1; } + [[nodiscard]] size_t num_devices() const { return devices_.size(); } + [[nodiscard]] bool contains_device(chip_id_t device_id) const; [[nodiscard]] Coordinate find_device(chip_id_t device_id) const; [[nodiscard]] chip_id_t find_device_id(const Coordinate& coord) const; + // Given a starting coordinate, get the coordinates of a line of devices where device[i-1] is connected to device[i] + // The current support only provides left-to-right and right-to-left snaking of the line. + [[nodiscard]] static std::vector get_line_coordinates(size_t length, const Coordinate& offset, size_t num_rows, size_t num_cols); + [[nodiscard]] std::vector get_ring_coordinates(const MeshShape& ring_shape, const Coordinate& offset, size_t num_rows, size_t num_cols); + [[nodiscard]] std::vector get_ring_devices(); + [[nodiscard]] std::vector get_line_devices(); + private: std::vector devices_; std::unordered_map device_coordinates_; @@ -119,10 +130,21 @@ inline MeshDeviceView make_mesh_device_view(std::vector devices, MeshDe } // namespace tt::tt_metal -// Specializations to enable structured bindings namespace std { - template<> struct tuple_size : std::integral_constant {}; - template struct tuple_element { - using type = std::size_t; + // Specializations to enable structured bindings + template<> struct tuple_size : std::integral_constant {}; + template struct tuple_element { + using type = size_t; + }; + + // Specialization to enable hashing of Coordinate + template <> + struct hash { + size_t operator()(const tt::tt_metal::Coordinate& coord) const noexcept { + size_t seed = 0; + tt::utils::hash_combine(seed, coord.row); + tt::utils::hash_combine(seed, coord.col); + return seed; + } }; } // namespace std diff --git a/tt_metal/third_party/tt_llk_blackhole b/tt_metal/third_party/tt_llk_blackhole index d7a12ee9eba4..05709f423aa7 160000 --- a/tt_metal/third_party/tt_llk_blackhole +++ b/tt_metal/third_party/tt_llk_blackhole @@ -1 +1 @@ -Subproject commit d7a12ee9eba4e17158d7e6731ce09d48be90eac4 +Subproject commit 05709f423aa713fd299f52f4779d09e791a3228e diff --git a/tt_metal/third_party/tt_llk_wormhole_b0 b/tt_metal/third_party/tt_llk_wormhole_b0 index 3457491ab21a..47bc7d232edd 160000 --- a/tt_metal/third_party/tt_llk_wormhole_b0 +++ b/tt_metal/third_party/tt_llk_wormhole_b0 @@ -1 +1 @@ -Subproject commit 3457491ab21aecd4325851c2607c35582f89e111 +Subproject commit 47bc7d232edd7d7974938ec539a5661e689f5b53 diff --git a/ttnn/cpp/pybind11/multi_device.hpp b/ttnn/cpp/pybind11/multi_device.hpp index 70d9755d0400..c9c661e04af0 100644 --- a/ttnn/cpp/pybind11/multi_device.hpp +++ b/ttnn/cpp/pybind11/multi_device.hpp @@ -16,9 +16,16 @@ namespace ttnn { namespace multi_device { -void py_module_types(py::module& module) { py::class_>(module, "MeshDevice"); } +void py_module_types(py::module& module) { + py::class_>(module, "MeshDevice"); +} void py_module(py::module& module) { + py::enum_(module, "MeshType") + .value("RowMajor", MeshType::RowMajor) + .value("Ring", MeshType::Ring) + .value("Line", MeshType::Line) + .export_values(); auto py_mesh_device = static_cast>>(module.attr("MeshDevice")); py_mesh_device .def( @@ -28,15 +35,15 @@ void py_module(py::module& module) { size_t num_command_queues, DispatchCoreType dispatch_core_type, const std::pair& offset, - const std::vector& physical_device_ids) { + const std::vector& physical_device_ids, + MeshType mesh_type) { + auto config = MeshDeviceConfig(mesh_device_shape, offset, physical_device_ids, mesh_type); return MeshDevice::create( - mesh_device_shape, l1_small_size, trace_region_size, num_command_queues, dispatch_core_type, - offset, - physical_device_ids); + config); }), py::kw_only(), py::arg("mesh_shape"), @@ -45,16 +52,18 @@ void py_module(py::module& module) { py::arg("num_command_queues"), py::arg("dispatch_core_type"), py::arg("offset"), - py::arg("physical_device_ids")) + py::arg("physical_device_ids"), + py::arg("mesh_type")) .def("get_num_devices", &MeshDevice::num_devices) + .def("get_mesh_id", &MeshDevice::get_mesh_id) .def("get_device_ids", &MeshDevice::get_device_ids) .def( "get_device", - py::overload_cast(&MeshDevice::get_device, py::const_), + py::overload_cast(&MeshDevice::get_device, py::const_), py::return_value_policy::reference) .def( "get_device", - py::overload_cast(&MeshDevice::get_device, py::const_), + py::overload_cast(&MeshDevice::get_device, py::const_), py::return_value_policy::reference) .def("get_devices", &MeshDevice::get_devices, py::return_value_policy::reference, R"doc( Get the devices in the device mesh. @@ -62,26 +71,12 @@ void py_module(py::module& module) { Returns: List[Device]: The devices in the device mesh. )doc") - .def( - "get_devices_on_row", - &MeshDevice::get_devices_on_row, - py::return_value_policy::reference, - R"doc( - Get the devices in a row of the device mesh. - - Returns: - List[Device]: The devices on a row in the device mesh. - )doc") - .def( - "get_devices_on_column", - &MeshDevice::get_devices_on_column, - py::return_value_policy::reference, - R"doc( - Get the devices in a row of the device mesh. - - Returns: - List[Device]: The devices on a row in the device mesh. - )doc") + .def("create_submesh", &MeshDevice::create_submesh, + py::arg("submesh_shape"), py::arg("offset"), py::arg("mesh_type"), + py::keep_alive<1, 0>()) // Keep MeshDevice alive as long as SubmeshDevice is alive + .def("create_submeshes", &MeshDevice::create_submeshes, + py::arg("submesh_shape"), py::arg("mesh_type"), + py::keep_alive<1, 0>()) // Keep MeshDevice alive as long as SubmeshDevices are alive .def( "compute_with_storage_grid_size", &MeshDevice::compute_with_storage_grid_size, @@ -126,7 +121,9 @@ void py_module(py::module& module) { py::arg("trace_region_size"), py::arg("num_command_queues"), py::arg("dispatch_core_type"), - py::arg("physical_device_ids")); + py::arg("offset"), + py::arg("physical_device_ids"), + py::arg("mesh_type")); module.def("close_mesh_device", &close_mesh_device, py::arg("mesh_device"), py::kw_only()); module.def( diff --git a/ttnn/cpp/ttnn/multi_device.cpp b/ttnn/cpp/ttnn/multi_device.cpp index 7fa5f9e0d650..b8a9e91a900b 100644 --- a/ttnn/cpp/ttnn/multi_device.cpp +++ b/ttnn/cpp/ttnn/multi_device.cpp @@ -12,8 +12,9 @@ namespace ttnn::multi_device { -std::shared_ptr open_mesh_device(const MeshShape& mesh_shape, size_t l1_small_size, size_t trace_region_size, size_t num_command_queues, DispatchCoreType dispatch_core_type, const std::pair& offset) { - return MeshDevice::create(mesh_shape, l1_small_size, trace_region_size, num_command_queues, dispatch_core_type, offset); +std::shared_ptr open_mesh_device(const MeshShape& mesh_shape, size_t l1_small_size, size_t trace_region_size, size_t num_command_queues, DispatchCoreType dispatch_core_type, MeshType mesh_type, const std::pair& offset, const std::vector& physical_device_ids) { + auto config = MeshDeviceConfig(mesh_shape, offset, physical_device_ids, mesh_type); + return MeshDevice::create(l1_small_size, trace_region_size, num_command_queues, dispatch_core_type, config); } void close_mesh_device(const std::shared_ptr& mesh_device) { diff --git a/ttnn/cpp/ttnn/multi_device.hpp b/ttnn/cpp/ttnn/multi_device.hpp index d7db05721bc5..ecd95d659b73 100644 --- a/ttnn/cpp/ttnn/multi_device.hpp +++ b/ttnn/cpp/ttnn/multi_device.hpp @@ -6,16 +6,25 @@ #include -#include "ttnn/types.hpp" -#include "ttnn/tensor/tensor.hpp" #include "tt_metal/impl/device/mesh_device.hpp" +#include "ttnn/tensor/tensor.hpp" +#include "ttnn/types.hpp" using Device = ttnn::Device; namespace ttnn { namespace multi_device { -std::shared_ptr open_mesh_device(const MeshShape& mesh_shape, size_t l1_small_size, size_t trace_region_size, size_t num_command_queues, DispatchCoreType dispatch_core_type, const std::pair& offset = {0, 0}); +std::shared_ptr open_mesh_device( + const MeshShape& mesh_shape, + size_t l1_small_size, + size_t trace_region_size, + size_t num_command_queues, + DispatchCoreType dispatch_core_type, + MeshType mesh_type = MeshType::RowMajor, + const std::pair& offset = std::pair(0, 0), + const std::vector& physical_device_ids = {}); + void close_mesh_device(const std::shared_ptr& mesh_device); std::vector get_device_tensors(const ttnn::Tensor& tensor); diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index ddd1e95725de..8dcad39de63b 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -197,6 +197,7 @@ def manage_config(name, value): visualize_mesh_device, ConcatMesh2dToTensor, distribute, + MeshType, ) from ttnn.core import ( diff --git a/ttnn/ttnn/multi_device.py b/ttnn/ttnn/multi_device.py index 87f3297dbee1..fa4abe3c86df 100644 --- a/ttnn/ttnn/multi_device.py +++ b/ttnn/ttnn/multi_device.py @@ -18,6 +18,7 @@ def get_mesh_device_core_grid(mesh_device): MeshDevice = ttnn._ttnn.multi_device.MeshDevice MeshDevice.core_grid = property(get_mesh_device_core_grid) DispatchCoreType = ttnn._ttnn.device.DispatchCoreType +MeshType = ttnn._ttnn.multi_device.MeshType def _get_rich_table( @@ -140,6 +141,7 @@ def open_mesh_device( dispatch_core_type: int = DispatchCoreType.WORKER, offset: Tuple[int, int] = (0, 0), physical_device_ids: List[int] = [], + mesh_type: "MeshType" = MeshType.RowMajor, ): """ Open a mesh device with the specified configuration. @@ -151,6 +153,7 @@ def open_mesh_device( num_command_queues (int, optional): Number of command queues. Defaults to 1. dispatch_core_type (int, optional): Type of dispatch core. Defaults to DispatchCoreType.WORKER. offset (Tuple[int, int], optional): Offset in logical mesh coordinates for the mesh device. Defaults to (0, 0). + mesh_type (MeshType, optional): Defines type of mesh requested. Type imposes connectivity constraints and defines device iteration order. Returns: ttnn._ttnn.multi_device.MeshDevice: The opened mesh device. @@ -164,6 +167,7 @@ def open_mesh_device( dispatch_core_type=dispatch_core_type, offset=offset, physical_device_ids=physical_device_ids, + mesh_type=mesh_type, )