From ed6dda9bd529d2a9a64a17c23cdb41360dd70478 Mon Sep 17 00:00:00 2001 From: Oleg Milyutin Date: Wed, 27 Nov 2024 15:23:53 -0500 Subject: [PATCH] #14974: ttnn::{full,empty}_like Tensor creation API for MeshDevice (#15333) ### Ticket #14974 ### Problem description Extensions to `Tensor` creation APIs to support `MeshDevice`. ### What's changed Follow up from https://github.com/tenstorrent/tt-metal/pull/15191: * Extend support for `MeshDevice` in `ttnn::empty_like`, `ttnn::full`, `ttnn::full_like`. * Minor refactor of `Tensor` allocation functions. ### Checklist - [X] [Post commit CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/12054314234) (the PR was updated to address the static checks) - [X] New/Existing tests provide coverage for changes - [X] [T3K unit tests + frequent tests](https://github.com/tenstorrent/tt-metal/actions/runs/12054324195) (the failures are the same as on main) --- .../test_create_tensor_multi_device.cpp | 131 ++++++- .../unit_tests/operations/test_creation.py | 140 +++---- .../sources/ttml/core/tt_tensor_utils.cpp | 5 +- ttnn/CMakeLists.txt | 2 + ttnn/cpp/pybind11/operations/creation.hpp | 350 ++++++++---------- ttnn/cpp/ttnn/any_device.hpp | 40 ++ ttnn/cpp/ttnn/distributed/api.cpp | 40 +- ttnn/cpp/ttnn/operations/core/core.cpp | 9 +- ttnn/cpp/ttnn/operations/creation.cpp | 31 ++ ttnn/cpp/ttnn/operations/creation.hpp | 195 ++++++---- ttnn/cpp/ttnn/operations/numpy/functions.hpp | 119 ++---- ttnn/cpp/ttnn/tensor/tensor.cpp | 38 +- ttnn/cpp/ttnn/tensor/tensor.hpp | 14 +- ttnn/cpp/ttnn/tensor/tensor_utils.cpp | 13 +- 14 files changed, 645 insertions(+), 482 deletions(-) create mode 100644 ttnn/cpp/ttnn/any_device.hpp create mode 100644 ttnn/cpp/ttnn/operations/creation.cpp diff --git a/tests/ttnn/unit_tests/gtests/tensor/test_create_tensor_multi_device.cpp b/tests/ttnn/unit_tests/gtests/tensor/test_create_tensor_multi_device.cpp index e92e583d59e..585326afc8b 100644 --- a/tests/ttnn/unit_tests/gtests/tensor/test_create_tensor_multi_device.cpp +++ b/tests/ttnn/unit_tests/gtests/tensor/test_create_tensor_multi_device.cpp @@ -3,6 +3,7 @@ // // SPDX-License-Identifier: Apache-2.0 +#include #include #include "buffers/buffer_constants.hpp" @@ -24,11 +25,11 @@ using ::tt::tt_metal::TensorMemoryLayout; class MultiDeviceTensorCreationTest : public T3kMultiDeviceFixture, public ::testing::WithParamInterface {}; -TEST_P(MultiDeviceTensorCreationTest, CreateEmpty) { +TEST_P(MultiDeviceTensorCreationTest, Empty) { MeshDevice* mesh_device = this->mesh_device_.get(); mesh_device->enable_async(GetParam()); - const auto mesh_replicated_tensor = ttnn::empty( + const Tensor mesh_replicated_tensor = ttnn::empty( ttnn::Shape(std::array{32, 32}), DataType::BFLOAT16, Layout::ROW_MAJOR, @@ -39,7 +40,133 @@ TEST_P(MultiDeviceTensorCreationTest, CreateEmpty) { EXPECT_EQ(mesh_replicated_tensor.get_workers().size(), mesh_device->num_devices()); const auto distributed_tensor_config = get_distributed_tensor_config_from_tensor(mesh_replicated_tensor); + EXPECT_TRUE(std::holds_alternative(distributed_tensor_config)); +} + +TEST_P(MultiDeviceTensorCreationTest, EmptyLike) { + MeshDevice* mesh_device = this->mesh_device_.get(); + mesh_device->enable_async(GetParam()); + + ASSERT_FALSE(mesh_device->get_devices().empty()); + + const Tensor tensor = ttnn::empty( + ttnn::Shape(std::array{32, 32}), + DataType::BFLOAT16, + Layout::ROW_MAJOR, + mesh_device->get_devices().at(0), + MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, std::nullopt}); + + EXPECT_EQ(tensor.storage_type(), 
StorageType::DEVICE); + EXPECT_EQ(tensor.get_workers().size(), 1); + + const Tensor mesh_replicated_tensor = ttnn::empty_like( + tensor, + DataType::BFLOAT16, + Layout::ROW_MAJOR, + *mesh_device, + MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, std::nullopt}); + EXPECT_EQ(mesh_replicated_tensor.storage_type(), StorageType::MULTI_DEVICE); + EXPECT_EQ(mesh_replicated_tensor.get_workers().size(), mesh_device->num_devices()); + + const auto distributed_tensor_config = get_distributed_tensor_config_from_tensor(mesh_replicated_tensor); + EXPECT_TRUE(std::holds_alternative(distributed_tensor_config)); +} + +TEST_P(MultiDeviceTensorCreationTest, Full) { + MeshDevice* mesh_device = this->mesh_device_.get(); + mesh_device->enable_async(GetParam()); + + const Tensor mesh_replicated_tensor = ttnn::full( + ttnn::Shape(std::array{32, 32}), + /*fill_value=*/42, + DataType::BFLOAT16, + Layout::ROW_MAJOR, + std::ref(*mesh_device), + MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, std::nullopt}); + + EXPECT_EQ(mesh_replicated_tensor.storage_type(), StorageType::MULTI_DEVICE); + EXPECT_EQ(mesh_replicated_tensor.get_workers().size(), mesh_device->num_devices()); + EXPECT_EQ(mesh_replicated_tensor.shape(), ttnn::SimpleShape({32, 32})); + EXPECT_EQ(mesh_replicated_tensor.dtype(), DataType::BFLOAT16); + EXPECT_EQ(mesh_replicated_tensor.layout(), Layout::ROW_MAJOR); + + const auto distributed_tensor_config = get_distributed_tensor_config_from_tensor(mesh_replicated_tensor); + EXPECT_TRUE(std::holds_alternative(distributed_tensor_config)); +} + +TEST_P(MultiDeviceTensorCreationTest, FullLike) { + MeshDevice* mesh_device = this->mesh_device_.get(); + mesh_device->enable_async(GetParam()); + + ASSERT_FALSE(mesh_device->get_devices().empty()); + + Tensor tensor = ttnn::empty( + ttnn::Shape(std::array{32, 32}), + DataType::BFLOAT16, + Layout::ROW_MAJOR, + mesh_device->get_devices().at(0), + MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, std::nullopt}); + + EXPECT_EQ(tensor.storage_type(), StorageType::DEVICE); + EXPECT_EQ(tensor.get_workers().size(), 1); + + Tensor mesh_replicated_tensor = ttnn::full_like( + tensor, + /*fill_value=*/42, + /*dtype=*/std::nullopt, + /*layout=*/std::nullopt, + std::ref(*mesh_device)); + + EXPECT_EQ(mesh_replicated_tensor.storage_type(), StorageType::MULTI_DEVICE); + EXPECT_EQ(mesh_replicated_tensor.get_workers().size(), mesh_device->num_devices()); + EXPECT_EQ(mesh_replicated_tensor.shape(), tensor.shape()); + EXPECT_EQ(mesh_replicated_tensor.dtype(), tensor.dtype()); + EXPECT_EQ(mesh_replicated_tensor.layout(), tensor.layout()); + + const auto distributed_tensor_config = get_distributed_tensor_config_from_tensor(mesh_replicated_tensor); + EXPECT_TRUE(std::holds_alternative(distributed_tensor_config)); +} + +TEST_P(MultiDeviceTensorCreationTest, FullLikeWithOptTensor) { + MeshDevice* mesh_device = this->mesh_device_.get(); + mesh_device->enable_async(GetParam()); + + ASSERT_FALSE(mesh_device->get_devices().empty()); + + Tensor tensor = ttnn::empty( + ttnn::Shape(std::array{32, 32}), + DataType::BFLOAT16, + Layout::ROW_MAJOR, + mesh_device->get_devices().at(0), + MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, std::nullopt}); + + EXPECT_EQ(tensor.storage_type(), StorageType::DEVICE); + EXPECT_EQ(tensor.get_workers().size(), 1); + + Tensor opt_output = ttnn::empty( + ttnn::Shape(std::array{32, 32}), + DataType::BFLOAT16, + Layout::ROW_MAJOR, + mesh_device, + MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, 
std::nullopt}); + + Tensor mesh_replicated_tensor = ttnn::full_like( + tensor, + /*fill_value=*/42, + /*dtype=*/std::nullopt, + /*layout=*/std::nullopt, + /*device=*/std::nullopt, + /*memory_config=*/std::nullopt, + opt_output); + + EXPECT_EQ(mesh_replicated_tensor.storage_type(), StorageType::MULTI_DEVICE); + EXPECT_EQ(mesh_replicated_tensor.get_workers().size(), mesh_device->num_devices()); + EXPECT_EQ(mesh_replicated_tensor.shape(), tensor.shape()); + EXPECT_EQ(mesh_replicated_tensor.dtype(), tensor.dtype()); + EXPECT_EQ(mesh_replicated_tensor.layout(), tensor.layout()); + + const auto distributed_tensor_config = get_distributed_tensor_config_from_tensor(mesh_replicated_tensor); EXPECT_TRUE(std::holds_alternative(distributed_tensor_config)); } diff --git a/tests/ttnn/unit_tests/operations/test_creation.py b/tests/ttnn/unit_tests/operations/test_creation.py index 07f13d5708f..f6f6773dc81 100644 --- a/tests/ttnn/unit_tests/operations/test_creation.py +++ b/tests/ttnn/unit_tests/operations/test_creation.py @@ -32,62 +32,6 @@ def test_zeros_like(device, input_shape): assert torch.allclose(torch_output_tensor, output_tensor) -@pytest.mark.parametrize( - "input_shape", - [ - [32, 32], - [5, 96, 64], - ], -) -def test_zeros_like_bf8b(device, input_shape): - torch_input_tensor = torch.rand((input_shape), dtype=torch.bfloat16) - torch_output_tensor = torch.zeros_like(torch_input_tensor) - - input_tensor = ttnn.from_torch(torch_input_tensor, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device) - output_tensor = ttnn.zeros_like(input_tensor) - assert ttnn.is_tensor_storage_on_device(output_tensor) - output_tensor = ttnn.from_device(output_tensor) - output_tensor = ttnn.to_torch(output_tensor).to(torch.bfloat16) - - assert_with_pcc(torch_output_tensor, output_tensor, 0.9999) - assert torch.allclose(torch_output_tensor, output_tensor) - - -@pytest.mark.parametrize( - "input_shape", - [ - [32, 32], - [5, 96, 64], - ], -) -@pytest.mark.parametrize( - "layout", - [ttnn.Layout.ROW_MAJOR, ttnn.Layout.TILE], -) -def test_zeros_like_opt(device, layout, input_shape): - torch_input_tensor = torch.rand((input_shape), dtype=torch.bfloat16) - torch_output_tensor = torch.zeros_like(torch_input_tensor) - opt_tensor = torch.ones(input_shape, dtype=torch.bfloat16) - opt_tensor = ttnn.from_torch( - opt_tensor, ttnn.bfloat16, layout=layout, device=device, memory_config=ttnn.L1_MEMORY_CONFIG - ) - - input_tensor = ttnn.from_torch(torch_input_tensor, layout=layout) - input_tensor = ttnn.to_device(input_tensor, device) - - cq_id = 0 - pages_before = ttnn._ttnn.reports.get_buffer_pages() - ttnn.zeros_like(input_tensor, optional_tensor=opt_tensor, queue_id=cq_id) - assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages()) - - assert ttnn.is_tensor_storage_on_device(opt_tensor) - opt_tensor = ttnn.from_device(opt_tensor) - opt_tensor = ttnn.to_torch(opt_tensor) - - assert_with_pcc(torch_output_tensor, opt_tensor, 0.9999) - assert torch.allclose(torch_output_tensor, opt_tensor) - - @pytest.mark.parametrize( "input_shape", [ @@ -110,35 +54,13 @@ def test_ones_like(device, input_shape): assert torch.allclose(torch_output_tensor, output_tensor) -@pytest.mark.parametrize( - "input_shape", - [ - [32, 32], - [5, 96, 64], - ], -) -def test_ones_like_bf8b(device, input_shape): - torch_input_tensor = torch.rand((input_shape), dtype=torch.bfloat16) - torch_output_tensor = torch.ones_like(torch_input_tensor) - - input_tensor = ttnn.from_torch(torch_input_tensor, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, 
device=device) - input_tensor = ttnn.to_device(input_tensor, device) - output_tensor = ttnn.ones_like(input_tensor) - assert ttnn.is_tensor_storage_on_device(output_tensor) - output_tensor = ttnn.from_device(output_tensor) - output_tensor = ttnn.to_torch(output_tensor).to(torch.bfloat16) - - assert_with_pcc(torch_output_tensor, output_tensor, 0.9999) - assert torch.allclose(torch_output_tensor, output_tensor) - - @pytest.mark.parametrize( "input_shape", [[32, 32], [5, 96, 64], [1, 2, 64, 64], [1, 2, 4, 64, 64]], ) @pytest.mark.parametrize( "fill_value", - [-5, 3, 15, 25], + [-5.25, 0, 1.0], ) def test_full_like(device, input_shape, fill_value): torch_input_tensor = torch.rand((input_shape), dtype=torch.bfloat16) @@ -161,7 +83,7 @@ def test_full_like(device, input_shape, fill_value): ) @pytest.mark.parametrize( "fill_value", - [-5, 3, 15, 25], + [-5.25, 0, 1.0], ) def test_full_like_bf8b(device, input_shape, fill_value): torch_input_tensor = torch.rand((input_shape), dtype=torch.bfloat16) @@ -187,7 +109,7 @@ def test_full_like_bf8b(device, input_shape, fill_value): ) @pytest.mark.parametrize( "fill_value", - [-5, 3, 15, 25], + [-5.25, 0, 1.0], ) @pytest.mark.parametrize( "layout", @@ -286,6 +208,7 @@ def test_full(device, input_shape, fill_value, layout): [ [32, 32], [5, 96, 64], + [1, 50257], ], ) @pytest.mark.parametrize( @@ -314,6 +237,34 @@ def test_full_with_opt_tensor(device, input_shape, layout, fill_value): assert torch.allclose(torch_tensor, opt_tensor) +@pytest.mark.parametrize( + "input_shape", + [ + [32, 32], + [5, 96, 64], + [1, 50257], + ], +) +@pytest.mark.parametrize( + "fill_value", + [-5.25, 0, 1.0], +) +@pytest.mark.parametrize( + "layout", + [ttnn.Layout.ROW_MAJOR, ttnn.Layout.TILE], +) +def test_full_multi_device(mesh_device, input_shape, fill_value, layout): + torch_tensor = torch.full(input_shape, dtype=torch.bfloat16, fill_value=fill_value) + + tensor = ttnn.full(input_shape, device=mesh_device, fill_value=fill_value, layout=layout) + assert ttnn.is_tensor_storage_on_device(tensor) + output_tensors = ttnn.to_torch(tensor, mesh_composer=ttnn.ListMeshToTensor(mesh_device)) + + for output_tensor in output_tensors: + assert_with_pcc(torch_tensor, output_tensor, 0.9999) + assert torch.allclose(torch_tensor, output_tensor) + + @pytest.mark.parametrize( "start", [4, 8, 16, 32], @@ -403,7 +354,6 @@ def test_empty_multi_device(mesh_device, input_shapes): ) def test_empty_like(device, input_shapes): torch_input_tensor = torch.ones((input_shapes), dtype=torch.bfloat16) - torch_output_tensor = torch.empty(torch_input_tensor.shape, dtype=torch.bfloat16) input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT) input_tensor = ttnn.to_device(input_tensor, device) @@ -412,4 +362,28 @@ def test_empty_like(device, input_shapes): output_tensor = ttnn.from_device(output_tensor) output_tensor = ttnn.to_torch(output_tensor) - assert list(torch_output_tensor.shape) == list(output_tensor.shape) + assert list(torch_input_tensor.shape) == list(output_tensor.shape) + + +@pytest.mark.parametrize( + "input_shapes", + [ + [2, 1, 4, 4], # 256x256 + [2, 1280, 8, 8], + [2, 640, 16, 16], + [2, 1280, 8, 8], # 512x512 + [2, 1280, 16, 16], + [2, 1280, 16, 16], + ], +) +def test_empty_like_multi_device(mesh_device, input_shapes): + torch_input_tensor = torch.empty((input_shapes), dtype=torch.bfloat16) + + input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT) + input_tensor = ttnn.to_device(input_tensor, mesh_device) + output_tensor = ttnn.empty_like(input_tensor, 
layout=ttnn.TILE_LAYOUT) + output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT) + output_tensor = ttnn.from_device(output_tensor) + output_tensors = ttnn.to_torch(output_tensor, mesh_composer=ttnn.ListMeshToTensor(mesh_device)) + for output_tensor in output_tensors: + assert list(torch_input_tensor.shape) == list(output_tensor.shape) diff --git a/tt-train/sources/ttml/core/tt_tensor_utils.cpp b/tt-train/sources/ttml/core/tt_tensor_utils.cpp index ac8e946f19e..d9f20c55ff1 100644 --- a/tt-train/sources/ttml/core/tt_tensor_utils.cpp +++ b/tt-train/sources/ttml/core/tt_tensor_utils.cpp @@ -166,11 +166,10 @@ tt::tt_metal::Tensor full( (padded[2] + additional_padding_h), (padded[3] + additional_padding_w), }); - // temporary solution to avoid using the device, and use only MeshDevice in highlevel api - return ttnn::full(padded_shape, value, dtype, Layout::TILE, std::ref(*device->get_device(0))); + return ttnn::full(padded_shape, value, dtype, Layout::TILE, std::ref(*device)); } // if not padding available, we can just create a tensor with the given shape - return ttnn::full(shape, value, dtype, Layout::TILE, std::ref(*device->get_device(0))); + return ttnn::full(shape, value, dtype, Layout::TILE, std::ref(*device)); } tt::tt_metal::Tensor zeros(const ttnn::Shape& shape, ttnn::distributed::MeshDevice* device, DataType dtype) { diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index 6afcf05d46a..02ffbf54acb 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -391,6 +391,7 @@ set(ALL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/clone/clone_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/clone/device/clone_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/clone/device/clone_device_operation.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/creation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/sharding_utilities.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/uniform/uniform.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/uniform/uniform_pybind.cpp @@ -620,6 +621,7 @@ set(TTNN_PUBLIC_LINK_DIRS "") set(TTNN_PRECOMPILED_HEADERS ${PROJECT_SOURCE_DIR}/tt_metal/tt_stl/reflection.hpp ${PROJECT_SOURCE_DIR}/ttnn/cpp/ttnn/operation.hpp + ${PROJECT_SOURCE_DIR}/ttnn/cpp/ttnn/any_device.hpp ${PROJECT_SOURCE_DIR}/tt_metal/third_party/tracy/public/tracy/Tracy.hpp ${PROJECT_SOURCE_DIR}/tt_metal/third_party/umd/device/device_api_metal.h ${PROJECT_SOURCE_DIR}/tt_metal/third_party/umd/device/cluster.h diff --git a/ttnn/cpp/pybind11/operations/creation.hpp b/ttnn/cpp/pybind11/operations/creation.hpp index e7c59a44e20..6581766ccb0 100644 --- a/ttnn/cpp/pybind11/operations/creation.hpp +++ b/ttnn/cpp/pybind11/operations/creation.hpp @@ -15,9 +15,141 @@ namespace py = pybind11; namespace ttnn { namespace operations { namespace creation { - namespace detail { +template +auto create_pybind_full_overload() { + return ttnn::pybind_overload_t{ + [](const creation_operation_t& self, + const std::vector& shape, + const fill_value_t fill_value, + const std::optional& dtype, + const std::optional& layout, + const std::optional> device, + const std::optional& memory_config, + std::optional& optional_output_tensor, + uint8_t queue_id) -> ttnn::Tensor { + return self( + queue_id, + ttnn::Shape{tt::tt_metal::LegacyShape{shape}}, + fill_value, + dtype, + layout, + device, + memory_config, + optional_output_tensor); + }, + py::arg("shape"), + py::arg("fill_value"), + py::arg("dtype") = std::nullopt, + 
py::arg("layout") = std::nullopt, + py::arg("device") = std::nullopt, + py::arg("memory_config") = std::nullopt, + py::arg("optional_tensor") = std::nullopt, + py::arg("queue_id") = ttnn::DefaultQueueId}; +} + +template +auto create_pybind_full_with_hard_coded_value_overload() { + return ttnn::pybind_overload_t{ + [](const creation_operation_t& self, + const std::vector& shape, + const std::optional& dtype, + const std::optional& layout, + const std::optional> device, + const std::optional& memory_config) -> ttnn::Tensor { + return self(ttnn::Shape{tt::tt_metal::LegacyShape{shape}}, dtype, layout, device, memory_config); + }, + py::arg("shape"), + py::arg("dtype") = std::nullopt, + py::arg("layout") = std::nullopt, + py::arg("device") = std::nullopt, + py::arg("memory_config") = std::nullopt}; +} + +template +auto create_pybind_full_like_overload() { + return ttnn::pybind_overload_t{ + [](const creation_operation_t& self, + const ttnn::Tensor& tensor, + const fill_value_t fill_value, + const std::optional& dtype, + const std::optional& layout, + const std::optional> device, + const std::optional& memory_config, + std::optional& optional_output_tensor, + uint8_t queue_id) -> ttnn::Tensor { + return self(queue_id, tensor, fill_value, dtype, layout, device, memory_config, optional_output_tensor); + }, + py::arg("tensor"), + py::arg("fill_value"), + py::arg("dtype") = std::nullopt, + py::arg("layout") = std::nullopt, + py::arg("device") = std::nullopt, + py::arg("memory_config") = std::nullopt, + py::arg("optional_tensor") = std::nullopt, + py::arg("queue_id") = ttnn::DefaultQueueId}; +} + +template +auto create_pybind_full_like_with_hard_coded_value_overload() { + return ttnn::pybind_overload_t{ + [](const creation_operation_t& self, + const ttnn::Tensor& tensor, + const std::optional& dtype, + const std::optional& layout, + const std::optional> device, + const std::optional& memory_config, + std::optional& optional_output_tensor, + uint8_t queue_id) -> ttnn::Tensor { + return self(queue_id, tensor, dtype, layout, device, memory_config, optional_output_tensor); + }, + py::arg("tensor"), + py::arg("dtype") = std::nullopt, + py::arg("layout") = std::nullopt, + py::arg("device") = std::nullopt, + py::arg("memory_config") = std::nullopt, + py::arg("optional_tensor") = std::nullopt, + py::arg("queue_id") = ttnn::DefaultQueueId}; +} + +template +auto create_pybind_empty_overload() { + return ttnn::pybind_overload_t{ + [](const creation_operation_t& self, + const std::vector& shape, + const DataType& dtype, + const Layout& layout, + device_t* device, + const MemoryConfig& memory_config) -> ttnn::Tensor { + return self(ttnn::Shape{tt::tt_metal::LegacyShape{shape}}, dtype, layout, device, memory_config); + }, + py::arg("shape"), + py::arg("dtype") = DataType::BFLOAT16, + py::arg("layout") = Layout::ROW_MAJOR, + py::arg("device"), + py::arg("memory_config") = ttnn::DRAM_MEMORY_CONFIG}; +} + +template +auto create_pybind_empty_like_overload() { + return ttnn::pybind_overload_t{ + [](const creation_operation_t& self, + const ttnn::Tensor& reference, + const std::optional& dtype, + const std::optional& layout, + const std::optional> device, + const std::optional& memory_config) -> ttnn::Tensor { + return self(reference, dtype, layout, device, memory_config); + }, + py::arg("tensor"), + py::kw_only(), + py::arg("dtype") = DataType::BFLOAT16, + py::arg("layout") = Layout::ROW_MAJOR, + py::arg("device") = std::nullopt, + py::arg("memory_config") = ttnn::DRAM_MEMORY_CONFIG}; +} + template void 
bind_full_operation(py::module& module, const creation_operation_t& operation) { auto doc = fmt::format( @@ -29,7 +161,7 @@ void bind_full_operation(py::module& module, const creation_operation_t& operati fill_value (float): The value to fill the tensor with. dtype (ttnn.DataType, optional): The data type of the tensor. Defaults to `None`. layout (ttnn.Layout, optional): The layout of the tensor. Defaults to `None`. - device (ttnn.Device, optional): The device on which the tensor will be allocated. Defaults to `None`. + device (ttnn.Device | ttnn.MeshDevice, optional): The device on which the tensor will be allocated. Defaults to `None`. memory_config (ttnn.MemoryConfig, optional): The memory configuration of the tensor. Defaults to `None`. output_tensor (ttnn.Tensor, optional): Preallocated output tensor. Defaults to `None`. queue_id (int, optional): command queue id. Defaults to `0`. @@ -53,62 +185,10 @@ void bind_full_operation(py::module& module, const creation_operation_t& operati module, operation, doc, - ttnn::pybind_overload_t{ - [](const creation_operation_t& self, - const std::vector& shape, - const float fill_value, - const std::optional& dtype, - const std::optional& layout, - const std::optional>& device, - const std::optional& memory_config, - std::optional& optional_output_tensor, - uint8_t queue_id) -> ttnn::Tensor { - return self( - queue_id, - ttnn::Shape{tt::tt_metal::LegacyShape{shape}}, - fill_value, - dtype, - layout, - device, - memory_config, - optional_output_tensor); - }, - py::arg("shape"), - py::arg("fill_value"), - py::arg("dtype") = std::nullopt, - py::arg("layout") = std::nullopt, - py::arg("device") = std::nullopt, - py::arg("memory_config") = std::nullopt, - py::arg("optional_tensor") = std::nullopt, - py::arg("queue_id") = ttnn::DefaultQueueId}, - ttnn::pybind_overload_t{ - [](const creation_operation_t& self, - const std::vector& shape, - const int fill_value, - const std::optional& dtype, - const std::optional& layout, - const std::optional>& device, - const std::optional& memory_config, - std::optional& optional_output_tensor, - uint8_t queue_id) -> ttnn::Tensor { - return self( - queue_id, - ttnn::Shape{tt::tt_metal::LegacyShape{shape}}, - fill_value, - dtype, - layout, - device, - memory_config, - optional_output_tensor); - }, - py::arg("shape"), - py::arg("fill_value"), - py::arg("dtype") = std::nullopt, - py::arg("layout") = std::nullopt, - py::arg("device") = std::nullopt, - py::arg("memory_config") = std::nullopt, - py::arg("optional_tensor") = std::nullopt, - py::arg("queue_id") = ttnn::DefaultQueueId}); + create_pybind_full_overload(), + create_pybind_full_overload(), + create_pybind_full_overload(), + create_pybind_full_overload()); } template @@ -125,7 +205,7 @@ void bind_full_operation_with_hard_coded_value( shape (ttnn.Shape): The shape of the tensor. dtype (ttnn.DataType, optional): The data type of the tensor. Defaults to `None`. layout (ttnn.Layout, optional): The layout of the tensor. Defaults to `None`. - device (ttnn.Device, optional): The device on which the tensor will be allocated. Defaults to `None`. + device (ttnn.Device | ttnn.MeshDevice, optional): The device on which the tensor will be allocated. Defaults to `None`. memory_config (ttnn.MemoryConfig, optional): The memory configuration of the tensor. Defaults to `None`. 
Note: @@ -154,20 +234,8 @@ void bind_full_operation_with_hard_coded_value( module, operation, doc, - ttnn::pybind_overload_t{ - [](const creation_operation_t& self, - const std::vector& shape, - const std::optional& dtype, - const std::optional& layout, - const std::optional>& device, - const std::optional& memory_config) -> ttnn::Tensor { - return self(ttnn::Shape{tt::tt_metal::LegacyShape{shape}}, dtype, layout, device, memory_config); - }, - py::arg("shape"), - py::arg("dtype") = std::nullopt, - py::arg("layout") = std::nullopt, - py::arg("device") = std::nullopt, - py::arg("memory_config") = std::nullopt}); + create_pybind_full_with_hard_coded_value_overload(), + create_pybind_full_with_hard_coded_value_overload()); } template @@ -181,7 +249,7 @@ void bind_full_like_operation(py::module& module, const creation_operation_t& op fill_value (float | int): The value to fill the tensor with. dtype (ttnn.DataType, optional): The data type of the tensor. Defaults to `None`. layout (ttnn.Layout, optional): The layout of the tensor. Defaults to `None`. - device (ttnn.Device, optional): The device on which the tensor will be allocated. Defaults to `None`. + device (ttnn.Device | ttnn.MeshDevice, optional): The device on which the tensor will be allocated. Defaults to `None`. memory_config (ttnn.MemoryConfig, optional): The memory configuration of the tensor. Defaults to `None`. output_tensor (ttnn.Tensor, optional): Preallocated output tensor. Defaults to `None`. queue_id (int, optional): command queue id. Defaults to `0`. @@ -202,46 +270,10 @@ void bind_full_like_operation(py::module& module, const creation_operation_t& op module, operation, doc, - ttnn::pybind_overload_t{ - [](const creation_operation_t& self, - const ttnn::Tensor& tensor, - const float fill_value, - const std::optional& dtype, - const std::optional& layout, - const std::optional>& device, - const std::optional& memory_config, - std::optional& optional_output_tensor, - uint8_t queue_id) -> ttnn::Tensor { - return self(queue_id, tensor, fill_value, dtype, layout, device, memory_config, optional_output_tensor); - }, - py::arg("tensor"), - py::arg("fill_value"), - py::arg("dtype") = std::nullopt, - py::arg("layout") = std::nullopt, - py::arg("device") = std::nullopt, - py::arg("memory_config") = std::nullopt, - py::arg("optional_tensor") = std::nullopt, - py::arg("queue_id") = ttnn::DefaultQueueId}, - ttnn::pybind_overload_t{ - [](const creation_operation_t& self, - const ttnn::Tensor& tensor, - const int fill_value, - const std::optional& dtype, - const std::optional& layout, - const std::optional>& device, - const std::optional& memory_config, - std::optional& optional_output_tensor, - uint8_t queue_id) -> ttnn::Tensor { - return self(queue_id, tensor, fill_value, dtype, layout, device, memory_config, optional_output_tensor); - }, - py::arg("tensor"), - py::arg("fill_value"), - py::arg("dtype") = std::nullopt, - py::arg("layout") = std::nullopt, - py::arg("device") = std::nullopt, - py::arg("memory_config") = std::nullopt, - py::arg("optional_tensor") = std::nullopt, - py::arg("queue_id") = ttnn::DefaultQueueId}); + create_pybind_full_like_overload(), + create_pybind_full_like_overload(), + create_pybind_full_like_overload(), + create_pybind_full_like_overload()); } template @@ -258,7 +290,7 @@ void bind_full_like_operation_with_hard_coded_value( tensor (ttnn.Tensor): The tensor to use as a template for the shape of the new tensor. dtype (ttnn.DataType, optional): The data type of the tensor. Defaults to `None`. 
layout (ttnn.Layout, optional): The layout of the tensor. Defaults to `None`. - device (ttnn.Device, optional): The device on which the tensor will be allocated. Defaults to `None`. + device (ttnn.Device | ttnn.MeshDevice, optional): The device on which the tensor will be allocated. Defaults to `None`. memory_config (ttnn.MemoryConfig, optional): The memory configuration of the tensor. Defaults to `None`. output_tensor (ttnn.Tensor, optional): Preallocated output tensor. Defaults to `None`. queue_id (int, optional): command queue id. Defaults to `0`. @@ -286,24 +318,8 @@ void bind_full_like_operation_with_hard_coded_value( module, operation, doc, - ttnn::pybind_overload_t{ - [](const creation_operation_t& self, - const ttnn::Tensor& tensor, - const std::optional& dtype, - const std::optional& layout, - const std::optional>& device, - const std::optional& memory_config, - std::optional& optional_output_tensor, - uint8_t queue_id) -> ttnn::Tensor { - return self(queue_id, tensor, dtype, layout, device, memory_config, optional_output_tensor); - }, - py::arg("tensor"), - py::arg("dtype") = std::nullopt, - py::arg("layout") = std::nullopt, - py::arg("device") = std::nullopt, - py::arg("memory_config") = std::nullopt, - py::arg("optional_tensor") = std::nullopt, - py::arg("queue_id") = ttnn::DefaultQueueId}); + create_pybind_full_like_with_hard_coded_value_overload(), + create_pybind_full_like_with_hard_coded_value_overload()); } template @@ -352,7 +368,8 @@ void bind_arange_operation(py::module& module, const creation_operation_t& opera py::arg("memory_config") = ttnn::DRAM_MEMORY_CONFIG}); } -void bind_empty_operation(py::module& module, const std::string& info_doc = "") { +template +void bind_empty_operation(py::module& module, const creation_operation_t& operation, const std::string& info_doc = "") { auto doc = fmt::format( R"doc( Creates a device tensor with uninitialized values of the specified shape, data type, layout, and memory configuration. 
@@ -375,45 +392,19 @@ void bind_empty_operation(py::module& module, const std::string& info_doc = "") >>> print(tensor) ttnn.Tensor([[[[0.9, 0.21, 0.5], [0.67, 0.11, 0.30]]]], shape=Shape([2, 3]), dtype=DataType::BFLOAT16, layout=Layout::TILE) )doc", - ttnn::empty.base_name(), + operation.base_name(), info_doc); - using EmptyType = decltype(ttnn::empty); bind_registered_operation( module, - ttnn::empty, + operation, doc, - ttnn::pybind_overload_t{ - [](const EmptyType& self, - const std::vector& shape, - const DataType& dtype, - const Layout& layout, - Device* device, - const MemoryConfig& memory_config) -> ttnn::Tensor { - return self(ttnn::Shape{tt::tt_metal::LegacyShape{shape}}, dtype, layout, device, memory_config); - }, - py::arg("shape"), - py::arg("dtype") = DataType::BFLOAT16, - py::arg("layout") = Layout::ROW_MAJOR, - py::arg("device"), - py::arg("memory_config") = ttnn::DRAM_MEMORY_CONFIG}, - ttnn::pybind_overload_t{ - [](const EmptyType& self, - const std::vector& shape, - const DataType& dtype, - const Layout& layout, - MeshDevice* device, - const MemoryConfig& memory_config) -> ttnn::Tensor { - return self(ttnn::Shape{tt::tt_metal::LegacyShape{shape}}, dtype, layout, device, memory_config); - }, - py::arg("shape"), - py::arg("dtype") = DataType::BFLOAT16, - py::arg("layout") = Layout::ROW_MAJOR, - py::arg("device"), - py::arg("memory_config") = ttnn::DRAM_MEMORY_CONFIG}); + create_pybind_empty_overload(), + create_pybind_empty_overload()); } -void bind_empty_like_operation(py::module& module) { +template +void bind_empty_like_operation(py::module& module, const creation_operation_t& operation) { auto doc = fmt::format( R"doc( Creates a new tensor with the same shape as the given `reference`, but without initializing its values. The data type, layout, device, and memory configuration of the new tensor can be specified. @@ -424,7 +415,7 @@ void bind_empty_like_operation(py::module& module) { Keyword Args: dtype (ttnn.DataType, optional): The desired data type of the output tensor. Defaults to `ttnn.bfloat16`. layout (ttnn.Layout, optional): The desired layout of the output tensor. Defaults to `ttnn.ROW_MAJOR`. - device (ttnn.Device, optional): The device where the output tensor will be allocated. Defaults to `None`. + device (ttnn.Device | ttnn.MeshDevice, optional): The device where the tensor will be allocated. Defaults to `None`. memory_config (ttnn.MemoryConfig, optional): The memory configuration for the operation. Defaults to `ttnn.DRAM_MEMORY_CONFIG`. 
Returns: @@ -436,28 +427,14 @@ void bind_empty_like_operation(py::module& module) { >>> print(tensor) ttnn.Tensor([[[[0.87, 0.45, 0.22], [0.60, 0.75, 0.25]]]], shape=Shape([2, 3]), dtype=DataType::BFLOAT16, layout=Layout::ROW_MAJOR) )doc", - ttnn::empty_like.base_name()); + operation.base_name()); - using EmptyLikeType = decltype(ttnn::empty_like); bind_registered_operation( module, - ttnn::empty_like, + operation, doc, - ttnn::pybind_overload_t{ - [](const EmptyLikeType& self, - const ttnn::Tensor& reference, - const std::optional& dtype, - const std::optional& layout, - const std::optional>& device, - const std::optional& memory_config) -> ttnn::Tensor { - return self(reference, dtype, layout, device, memory_config); - }, - py::arg("tensor"), - py::kw_only(), - py::arg("dtype") = DataType::BFLOAT16, - py::arg("layout") = Layout::ROW_MAJOR, - py::arg("device") = std::nullopt, - py::arg("memory_config") = ttnn::DRAM_MEMORY_CONFIG}); + create_pybind_empty_like_overload(), + create_pybind_empty_like_overload()); } } // namespace detail @@ -496,6 +473,7 @@ void py_module(py::module& module) { detail::bind_empty_operation( module, + ttnn::empty, R"doc(Supported dtypes, layouts, and ranks: +----------------------------+---------------------------------+-------------------+ @@ -505,7 +483,7 @@ void py_module(py::module& module) { +----------------------------+---------------------------------+-------------------+ | BFLOAT_8 | TILE | 2, 3, 4 | +----------------------------+---------------------------------+-------------------+)doc"); - detail::bind_empty_like_operation(module); + detail::bind_empty_like_operation(module, ttnn::empty_like); } } // namespace creation diff --git a/ttnn/cpp/ttnn/any_device.hpp b/ttnn/cpp/ttnn/any_device.hpp new file mode 100644 index 00000000000..50f50bfa676 --- /dev/null +++ b/ttnn/cpp/ttnn/any_device.hpp @@ -0,0 +1,40 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "tt_metal/distributed/mesh_device.hpp" +#include "tt_metal/impl/device/device.hpp" + +namespace ttnn { + +// AnyDevice is a wrapper around Device / MeshDevice to use in interfaces that can accept either. +// This class is cheaply copyable, use value semantics to pass it around. +// +// TODO: the eventual goal is to lower this primitive into tt_metal. In the long term, we also want to extend the +// functionality with the "distributed device" semantics. +class AnyDevice { +public: + // Allow implicit conversion for transparent migration. + // Expect the pointers to be non-null, and remain valid for the lifetime of AnyDevice. 
+    AnyDevice(tt::tt_metal::Device* device) : metal_device_{device} {}
+    AnyDevice(tt::tt_metal::distributed::MeshDevice* mesh_device) : metal_device_{mesh_device} {}
+    AnyDevice(const AnyDevice&) = default;
+    AnyDevice& operator=(const AnyDevice&) = default;
+    AnyDevice(AnyDevice&&) = delete;
+    AnyDevice& operator=(AnyDevice&&) = delete;
+
+    std::vector get_devices() {
+        if (auto* device = std::get_if(&metal_device_); device != nullptr) {
+            return {*device};
+        } else {
+            return std::get(metal_device_)->get_devices();
+        }
+    }
+
+private:
+    std::variant metal_device_;
+};
+
+}  // namespace ttnn
diff --git a/ttnn/cpp/ttnn/distributed/api.cpp b/ttnn/cpp/ttnn/distributed/api.cpp
index 1dc43636648..8924cc3bc6c 100644
--- a/ttnn/cpp/ttnn/distributed/api.cpp
+++ b/ttnn/cpp/ttnn/distributed/api.cpp
@@ -141,7 +141,10 @@ std::vector get_t3k_physical_device_ids_ring() {
 }
 
 std::vector distribute_tensor_to_mesh(const Tensor& tensor, MeshDevice& mesh_device) {
-    auto get_multi_device_workers = [&](const std::vector& workers) {
+    // For multi-device tensors, returns the number of workers capped by the number of buffers.
+    // Otherwise, returns all available workers from mesh_device.
+    auto get_workers_for_tensor = [&tensor, &mesh_device]() {
+        const auto& workers = mesh_device.get_devices();
         if (std::holds_alternative(tensor.get_storage()) or
             std::holds_alternative(tensor.get_storage())) {
             return std::vector(workers.begin(), workers.begin() + num_buffers_in_tensor(tensor));
@@ -156,49 +159,46 @@ std::vector distribute_tensor_to_mesh(const Tensor& tensor, MeshDevice&
             [&](const auto& strategy) {
                 using StrategyType = std::decay_t;
                 if constexpr (std::is_same_v) {
-                    auto mesh_view = mesh_device.get_view();
-                    return mesh_view->get_devices(strategy.shard_mesh);
+                    return mesh_device.get_view()->get_devices(strategy.shard_mesh);
                 } else {
-                    return get_multi_device_workers(mesh_device.get_devices());
+                    return get_workers_for_tensor();
                 }
             },
             host_storage.strategy);
     } else if (std::holds_alternative(tensor.get_storage())) {
         return tensor.workers;
     } else {
-        return get_multi_device_workers(mesh_device.get_devices());
+        return get_workers_for_tensor();
     }
 }
 
 DistributedTensorConfig get_distributed_tensor_config_from_tensor(const Tensor& tensor) {
     if (tensor.storage_type() == StorageType::MULTI_DEVICE) {
+        const auto* multi_device_storage = std::get_if(&tensor.get_storage());
         TT_ASSERT(
-            std::holds_alternative(tensor.get_storage()),
+            multi_device_storage != nullptr,
             "Unexpected type {}",
             tt::stl::get_active_type_name_in_variant(tensor.get_storage()));
-        const auto& tensor_storage = std::get(tensor.get_storage());
-        return tensor_storage.strategy;
+        return multi_device_storage->strategy;
     } else if (tensor.storage_type() == StorageType::MULTI_DEVICE_HOST) {
+        const auto* multi_device_host_storage = std::get_if(&tensor.get_storage());
         TT_ASSERT(
-            std::holds_alternative(tensor.get_storage()),
+            multi_device_host_storage != nullptr,
             "Unexpected type {}",
             tt::stl::get_active_type_name_in_variant(tensor.get_storage()));
-        const auto& tensor_storage = std::get(tensor.get_storage());
-        return tensor_storage.strategy;
+        return multi_device_host_storage->strategy;
    }
    TT_THROW("Tensor is not a multi-device tensor");
 }
 
 Tensor get_device_tensor(const Tensor& multi_device_tensor, const int device_id) {
-    if (std::holds_alternative(multi_device_tensor.get_storage())) {
-        const auto& tensor_storage = std::get(multi_device_tensor.get_storage());
-        if (tensor_storage.has_buffer_for_device_id(device_id)) {
-            return Tensor{
DeviceStorage{tensor_storage.get_buffer_for_device_id(device_id)}, - multi_device_tensor.get_legacy_shape(), - multi_device_tensor.get_dtype(), - multi_device_tensor.get_layout()}; - } + if (const auto* tensor_storage = std::get_if(&multi_device_tensor.get_storage()); + tensor_storage != nullptr && tensor_storage->has_buffer_for_device_id(device_id)) { + return Tensor{ + DeviceStorage{tensor_storage->get_buffer_for_device_id(device_id)}, + multi_device_tensor.get_shape(), + multi_device_tensor.get_dtype(), + multi_device_tensor.get_layout()}; } else if (std::holds_alternative(multi_device_tensor.get_storage())) { return multi_device_tensor; } diff --git a/ttnn/cpp/ttnn/operations/core/core.cpp b/ttnn/cpp/ttnn/operations/core/core.cpp index fac198ff601..184f6e139f1 100644 --- a/ttnn/cpp/ttnn/operations/core/core.cpp +++ b/ttnn/cpp/ttnn/operations/core/core.cpp @@ -15,6 +15,7 @@ #include "ttnn/distributed/types.hpp" #include "ttnn/operations/data_movement/sharded/sharded_to_interleaved/sharded_to_interleaved.hpp" #include "ttnn/operations/data_movement/sharded/interleaved_to_sharded/interleaved_to_sharded.hpp" +#include "ttnn/tensor/tensor.hpp" namespace ttnn::operations::core { @@ -85,8 +86,8 @@ ttnn::Tensor allocate_tensor_on_device( Layout layout, Device* device, const std::optional& memory_config) { - return tt::tt_metal::allocate_tensor_on_device( - shape, data_type, layout, device, memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG)); + return tt::tt_metal::allocate_tensor_on_devices( + shape, data_type, layout, {device}, memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG)); } ttnn::Tensor allocate_tensor_on_device( @@ -95,8 +96,8 @@ ttnn::Tensor allocate_tensor_on_device( Layout layout, MeshDevice* mesh_device, const std::optional& memory_config) { - return tt::tt_metal::allocate_tensor_on_device( - shape, data_type, layout, mesh_device, memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG)); + return tt::tt_metal::allocate_tensor_on_devices( + shape, data_type, layout, mesh_device->get_devices(), memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG)); } void copy_host_to_device_tensor(const ttnn::Tensor& host_tensor, ttnn::Tensor device_tensor, uint8_t cq_id) { diff --git a/ttnn/cpp/ttnn/operations/creation.cpp b/ttnn/cpp/ttnn/operations/creation.cpp new file mode 100644 index 00000000000..62c1aaffb94 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/creation.cpp @@ -0,0 +1,31 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttnn/operations/creation.hpp" + +#include + +#include "ttnn/any_device.hpp" + +namespace ttnn::operations::creation::detail { + +OptionalAnyDevice::OptionalAnyDevice(std::nullopt_t) {} +OptionalAnyDevice::OptionalAnyDevice(ttnn::AnyDevice device) : device_(std::make_optional(device)) {} + +// TODO: some of these won't be needed, as we unify the APIs. +OptionalAnyDevice::OptionalAnyDevice(const std::optional>& device) : + device_(device.has_value() ? std::make_optional(&device->get()) : std::nullopt) {} +OptionalAnyDevice::OptionalAnyDevice( + const std::optional>& mesh_device) : + device_(mesh_device.has_value() ? 
std::make_optional(&mesh_device->get()) : std::nullopt) {} +OptionalAnyDevice::OptionalAnyDevice(std::reference_wrapper device) : + device_(std::make_optional(&device.get())) {} +OptionalAnyDevice::OptionalAnyDevice(std::reference_wrapper mesh_device) : + device_(std::make_optional(&mesh_device.get())) {} + +OptionalAnyDevice::OptionalAnyDevice(tt::tt_metal::Device& device) : device_(std::make_optional(&device)) {} +OptionalAnyDevice::OptionalAnyDevice(tt::tt_metal::distributed::MeshDevice& mesh_device) : + device_(std::make_optional(&mesh_device)) {} + +} // namespace ttnn::operations::creation::detail diff --git a/ttnn/cpp/ttnn/operations/creation.hpp b/ttnn/cpp/ttnn/operations/creation.hpp index b0502c3aaf4..acd2914c98f 100644 --- a/ttnn/cpp/ttnn/operations/creation.hpp +++ b/ttnn/cpp/ttnn/operations/creation.hpp @@ -4,21 +4,73 @@ #pragma once -#include "ttnn/tensor/tensor.hpp" -#include "ttnn/tensor/tensor_utils.hpp" -#include "ttnn/tensor/types.hpp" -#include "ttnn/operations/numpy/functions.hpp" +#include +#include + #include "tt_metal/impl/dispatch/command_queue.hpp" +#include "ttnn/common/constants.hpp" #include "ttnn/core.hpp" #include "ttnn/decorators.hpp" -#include "ttnn/types.hpp" -#include "ttnn/common/constants.hpp" +#include "ttnn/distributed/types.hpp" #include "ttnn/operations/eltwise/unary/unary.hpp" +#include "ttnn/operations/numpy/functions.hpp" +#include "ttnn/any_device.hpp" +#include "ttnn/tensor/tensor.hpp" +#include "ttnn/tensor/tensor_utils.hpp" +#include "ttnn/tensor/types.hpp" +#include "ttnn/types.hpp" namespace ttnn { namespace operations { namespace creation { +namespace detail { + +// Non-type template parameters (NTTPs) disallow floating point values +// This works around that limitation by using a structural type +// https://godbolt.org/z/hxKje3MYe +template +struct boxed { + T value; + consteval boxed(T value) noexcept : value(value) {} + consteval auto invoke() const noexcept -> T { return value; } +}; + +// Helper class to transparently bind instances of Device / MeshDevice along with their reference wrappers to +// AnyDevice +class OptionalAnyDevice { +public: + OptionalAnyDevice() = default; + OptionalAnyDevice(std::nullopt_t); + OptionalAnyDevice(ttnn::AnyDevice device); + OptionalAnyDevice(const std::optional>& device); + OptionalAnyDevice(const std::optional>& mesh_device); + OptionalAnyDevice(std::reference_wrapper device); + OptionalAnyDevice(std::reference_wrapper mesh_device); + OptionalAnyDevice(tt::tt_metal::Device& device); + OptionalAnyDevice(tt::tt_metal::distributed::MeshDevice& mesh_device); + + OptionalAnyDevice(const OptionalAnyDevice&) = default; + OptionalAnyDevice& operator=(const OptionalAnyDevice&) = default; + OptionalAnyDevice(OptionalAnyDevice&&) = delete; + OptionalAnyDevice& operator=(OptionalAnyDevice&&) = delete; + + bool has_value() { return device_.has_value(); } + ttnn::AnyDevice* operator->() { return &(*device_); } + ttnn::AnyDevice operator*() { return *device_; } + +private: + std::optional device_; +}; + +// Converts an instance of AnyDevice to a vector of the underlying Devices. +// TODO: Consider moving the helper into a dedicated header with the related utils. +inline std::vector get_workers_from_device(OptionalAnyDevice device) { + return device.has_value() ? 
device->get_devices() : std::vector{}; +} + +} // namespace detail + template Tensor create_scalar(T scalar, DataType data_type, Layout layout, Device* device) { using namespace tt::constants; @@ -55,12 +107,12 @@ inline ttnn::Tensor full_impl( const T fill_value, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, - const std::optional>& device_arg = std::nullopt, + const std::vector& workers = {}, const std::optional& memory_config = std::nullopt, std::optional optional_output_tensor = std::nullopt) { - Device* device = optional_output_tensor.has_value() ? optional_output_tensor.value().device() - : device_arg.has_value() ? &(device_arg.value().get()) - : nullptr; + const std::vector& workers_to_use = + optional_output_tensor.has_value() ? optional_output_tensor->get_workers(/*blocking=*/true) : workers; + Layout layout_value = optional_output_tensor.has_value() ? optional_output_tensor.value().get_layout() : layout.value_or(ttnn::ROW_MAJOR_LAYOUT); DataType dtype_value = optional_output_tensor.has_value() ? optional_output_tensor.value().get_dtype() @@ -69,8 +121,9 @@ inline ttnn::Tensor full_impl( optional_output_tensor.has_value() ? optional_output_tensor.value().get_legacy_shape() : shape.value; MemoryConfig mem_cfg = optional_output_tensor.has_value() ? optional_output_tensor.value().memory_config() : memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG); + return numpy::full_impl( - queue_id, shape_value, fill_value, dtype_value, layout_value, device, mem_cfg, optional_output_tensor); + queue_id, shape_value, fill_value, dtype_value, layout_value, workers, mem_cfg, optional_output_tensor); } template @@ -79,27 +132,21 @@ inline ttnn::Tensor full( const T fill_value, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, - const std::optional>& device_arg = std::nullopt, + detail::OptionalAnyDevice device = std::nullopt, const std::optional& memory_config = std::nullopt, std::optional optional_output_tensor = std::nullopt, uint8_t queue_id = ttnn::DefaultQueueId) { - return full_impl(queue_id, shape, fill_value, dtype, layout, device_arg, memory_config, optional_output_tensor); + return full_impl( + queue_id, + shape, + fill_value, + dtype, + layout, + detail::get_workers_from_device(device), + memory_config, + optional_output_tensor); } -namespace detail { - -// Non-type template parameters (NTTPs) disallow floating point values -// This works around that limitation by using a structural type -// https://godbolt.org/z/hxKje3MYe -template -struct boxed { - T value; - consteval boxed(T value) noexcept : value(value) {} - consteval auto invoke() const noexcept -> T { return value; } -}; - -} // namespace detail - template struct FullWith { static constexpr auto fill_value = FillValue.invoke(); @@ -108,7 +155,7 @@ struct FullWith { const ttnn::Shape& shape, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, - const std::optional>& device = std::nullopt, + detail::OptionalAnyDevice device = std::nullopt, const std::optional& memory_config = std::nullopt) { return full(shape, fill_value, dtype, layout, device, memory_config); } @@ -127,7 +174,7 @@ inline ttnn::Tensor full_like_impl( const T fill_value, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, - const std::optional>& device = std::nullopt, + detail::OptionalAnyDevice device = std::nullopt, const std::optional& memory_config = std::nullopt, std::optional optional_output_tensor = std::nullopt) { 
Layout layout_value = optional_output_tensor.has_value() ? optional_output_tensor.value().get_layout() @@ -150,7 +197,7 @@ inline ttnn::Tensor full_like_impl( fill_value, dtype_value, layout_value, - device.value_or(*tensor.device()), + device.has_value() ? device->get_devices() : tensor.get_workers(/*blocking=*/true), memory_config.value_or(tensor.memory_config()), optional_output_tensor); } @@ -161,7 +208,7 @@ inline ttnn::Tensor full_like_impl( fill_value, dtype_value, layout_value, - device, + detail::get_workers_from_device(device), memory_config, optional_output_tensor); } @@ -173,7 +220,7 @@ inline ttnn::Tensor full_like( const T fill_value, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, - const std::optional>& device = std::nullopt, + detail::OptionalAnyDevice device = std::nullopt, const std::optional& memory_config = std::nullopt) { return full_like_impl(ttnn::DefaultQueueId, tensor, fill_value, dtype, layout, device, memory_config, std::nullopt); } @@ -187,7 +234,7 @@ struct FullLikeWith { const ttnn::Tensor& tensor, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, - const std::optional>& device = std::nullopt, + detail::OptionalAnyDevice device = std::nullopt, const std::optional& memory_config = std::nullopt, std::optional optional_output_tensor = std::nullopt) { return full_like_impl( @@ -198,7 +245,7 @@ struct FullLikeWith { const ttnn::Tensor& tensor, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, - const std::optional>& device = std::nullopt, + detail::OptionalAnyDevice device = std::nullopt, const std::optional& memory_config = std::nullopt, std::optional optional_output_tensor = std::nullopt) { return full_like_impl( @@ -217,18 +264,9 @@ struct Empty { const ttnn::Shape& shape, const DataType& dtype, const Layout& layout, - Device* device, + ttnn::AnyDevice device, const MemoryConfig& memory_config) { - return allocate_tensor_on_device(shape, dtype, layout, device, memory_config); - } - - static ttnn::Tensor invoke( - const ttnn::Shape& shape, - const DataType& dtype, - const Layout& layout, - distributed::MeshDevice* device, - const MemoryConfig& memory_config) { - return allocate_tensor_on_device(shape, dtype, layout, device, memory_config); + return allocate_tensor_on_devices(shape, dtype, layout, device.get_devices(), memory_config); } }; @@ -237,13 +275,14 @@ struct EmptyLike { const ttnn::Tensor& tensor, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, - const std::optional>& device_arg = std::nullopt, + detail::OptionalAnyDevice device_arg = std::nullopt, const std::optional& memory_config = std::nullopt) { - Device* device = device_arg.has_value() ? &(device_arg.value().get()) : tensor.device(); + const std::vector& devices = + device_arg.has_value() ? 
device_arg->get_devices() : tensor.get_workers(/*blocking=*/true); Layout layout_value = layout.value_or(tensor.get_layout()); DataType dtype_value = dtype.value_or(tensor.get_dtype()); MemoryConfig mem_cfg = memory_config.value_or(tensor.memory_config()); - return create_device_tensor(tensor.get_shape(), dtype_value, layout_value, device, mem_cfg); + return allocate_tensor_on_devices(tensor.get_shape(), dtype_value, layout_value, devices, mem_cfg); } }; @@ -254,10 +293,18 @@ struct Full { const float fill_value, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, - const std::optional>& device = std::nullopt, + detail::OptionalAnyDevice device = std::nullopt, const std::optional& memory_config = std::nullopt, std::optional optional_output_tensor = std::nullopt) { - return full_impl(queue_id, shape, fill_value, dtype, layout, device, memory_config, optional_output_tensor); + return full_impl( + queue_id, + shape, + fill_value, + dtype, + layout, + detail::get_workers_from_device(device), + memory_config, + optional_output_tensor); } static ttnn::Tensor invoke( @@ -266,10 +313,18 @@ struct Full { const int fill_value, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, - const std::optional>& device = std::nullopt, + detail::OptionalAnyDevice device = std::nullopt, const std::optional& memory_config = std::nullopt, std::optional optional_output_tensor = std::nullopt) { - return full_impl(queue_id, shape, fill_value, dtype, layout, device, memory_config, optional_output_tensor); + return full_impl( + queue_id, + shape, + fill_value, + dtype, + layout, + detail::get_workers_from_device(device), + memory_config, + optional_output_tensor); } static ttnn::Tensor invoke( @@ -277,11 +332,18 @@ struct Full { const float fill_value, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, - const std::optional>& device = std::nullopt, + detail::OptionalAnyDevice device = std::nullopt, const std::optional& memory_config = std::nullopt, std::optional optional_output_tensor = std::nullopt) { return full_impl( - ttnn::DefaultQueueId, shape, fill_value, dtype, layout, device, memory_config, optional_output_tensor); + ttnn::DefaultQueueId, + shape, + fill_value, + dtype, + layout, + detail::get_workers_from_device(device), + memory_config, + optional_output_tensor); } static ttnn::Tensor invoke( @@ -289,11 +351,18 @@ struct Full { const int fill_value, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, - const std::optional>& device = std::nullopt, + detail::OptionalAnyDevice device = std::nullopt, const std::optional& memory_config = std::nullopt, std::optional optional_output_tensor = std::nullopt) { return full_impl( - ttnn::DefaultQueueId, shape, fill_value, dtype, layout, device, memory_config, optional_output_tensor); + ttnn::DefaultQueueId, + shape, + fill_value, + dtype, + layout, + detail::get_workers_from_device(device), + memory_config, + optional_output_tensor); } }; @@ -304,7 +373,7 @@ struct FullLike { const float fill_value, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, - const std::optional>& device = std::nullopt, + detail::OptionalAnyDevice device = std::nullopt, const std::optional& memory_config = std::nullopt, std::optional optional_output_tensor = std::nullopt) { return full_like_impl( @@ -317,7 +386,7 @@ struct FullLike { const int fill_value, const std::optional& dtype = std::nullopt, const std::optional& 
layout = std::nullopt, - const std::optional>& device = std::nullopt, + detail::OptionalAnyDevice device = std::nullopt, const std::optional& memory_config = std::nullopt, std::optional optional_output_tensor = std::nullopt) { return full_like_impl( @@ -329,7 +398,7 @@ struct FullLike { const float fill_value, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, - const std::optional>& device = std::nullopt, + detail::OptionalAnyDevice device = std::nullopt, const std::optional& memory_config = std::nullopt, std::optional optional_output_tensor = std::nullopt) { return full_like_impl( @@ -341,7 +410,7 @@ struct FullLike { const int fill_value, const std::optional& dtype = std::nullopt, const std::optional& layout = std::nullopt, - const std::optional>& device = std::nullopt, + detail::OptionalAnyDevice device = std::nullopt, const std::optional& memory_config = std::nullopt, std::optional optional_output_tensor = std::nullopt) { return full_like_impl( @@ -349,6 +418,7 @@ struct FullLike { } }; +// TODO: #14974 - Onboard this API onto AnyDevice. struct Arange { static ttnn::Tensor invoke( const int64_t stop, @@ -385,14 +455,13 @@ struct Arange { } // namespace creation } // namespace operations -constexpr auto full = - ttnn::decorators::register_operation_with_auto_launch_op<"ttnn::full", ttnn::operations::creation::Full>(); +constexpr auto full = ttnn::decorators::register_operation<"ttnn::full", ttnn::operations::creation::Full>(); constexpr auto zeros = ttnn::decorators::register_operation<"ttnn::zeros", ttnn::operations::creation::Zeros>(); constexpr auto ones = ttnn::decorators::register_operation<"ttnn::ones", ttnn::operations::creation::Ones>(); constexpr auto empty = ttnn::decorators::register_operation<"ttnn::empty", ttnn::operations::creation::Empty>(); constexpr auto full_like = - ttnn::decorators::register_operation_with_auto_launch_op<"ttnn::full_like", ttnn::operations::creation::FullLike>(); + ttnn::decorators::register_operation<"ttnn::full_like", ttnn::operations::creation::FullLike>(); constexpr auto zeros_like = ttnn::decorators::register_operation<"ttnn::zeros_like", ttnn::operations::creation::ZerosLike>(); constexpr auto ones_like = diff --git a/ttnn/cpp/ttnn/operations/numpy/functions.hpp b/ttnn/cpp/ttnn/operations/numpy/functions.hpp index 331931fd746..31f1ec32efe 100644 --- a/ttnn/cpp/ttnn/operations/numpy/functions.hpp +++ b/ttnn/cpp/ttnn/operations/numpy/functions.hpp @@ -52,41 +52,42 @@ static Tensor full( uint8_t queue_id, const tt::tt_metal::LegacyShape& shape, T value, - const Layout layout = Layout::ROW_MAJOR, - Device* device = nullptr, - const MemoryConfig& output_mem_config = - MemoryConfig{.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}, - std::optional optional_output_tensor = std::nullopt) { + const Layout layout, + const std::vector& devices, + const MemoryConfig& output_mem_config, + std::optional optional_output_tensor) { constexpr DataType data_type = detail::get_data_type(); TensorSpec tensor_spec( shape.logical_shape(), TensorLayout::fromLegacyPaddedShape(data_type, PageConfig(layout), MemoryConfig{}, shape)); auto owned_buffer = tt::tt_metal::owned_buffer::create(tensor_spec.padded_shape().volume()); + // TODO: 15061 - Generalize the header to support generic vector / view types. 
diff --git a/ttnn/cpp/ttnn/operations/numpy/functions.hpp b/ttnn/cpp/ttnn/operations/numpy/functions.hpp
index 331931fd746..31f1ec32efe 100644
--- a/ttnn/cpp/ttnn/operations/numpy/functions.hpp
+++ b/ttnn/cpp/ttnn/operations/numpy/functions.hpp
@@ -52,41 +52,42 @@ static Tensor full(
     uint8_t queue_id,
     const tt::tt_metal::LegacyShape& shape,
     T value,
-    const Layout layout = Layout::ROW_MAJOR,
-    Device* device = nullptr,
-    const MemoryConfig& output_mem_config =
-        MemoryConfig{.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED},
-    std::optional<Tensor> optional_output_tensor = std::nullopt) {
+    const Layout layout,
+    const std::vector<Device*>& devices,
+    const MemoryConfig& output_mem_config,
+    std::optional<Tensor> optional_output_tensor) {
     constexpr DataType data_type = detail::get_data_type<T>();
     TensorSpec tensor_spec(
         shape.logical_shape(),
         TensorLayout::fromLegacyPaddedShape(data_type, PageConfig(layout), MemoryConfig{}, shape));
     auto owned_buffer = tt::tt_metal::owned_buffer::create<T>(tensor_spec.padded_shape().volume());
+    // TODO: 15061 - Generalize the header to support generic vector / view types.
     std::fill(std::begin(owned_buffer), std::end(owned_buffer), value);
 
     if (!optional_output_tensor.has_value()) {
         auto output = Tensor(OwnedStorage{owned_buffer}, shape, data_type, layout);
-        if (device != nullptr) {
-            output = output.to(device, output_mem_config);
+        if (!devices.empty()) {
+            output = output.to(devices, output_mem_config);
         }
         return output;
     } else {
-        auto device_buffer =
-            std::get<DeviceStorage>(optional_output_tensor.value().tensor_attributes->storage).get_buffer();
-        bool using_fast_dispatch = (std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr);
-
-        if (using_fast_dispatch && device != nullptr) {
-            auto& cmd_queue = device->command_queue(queue_id);
-            if (CommandQueue::default_mode() == CommandQueue::CommandQueueMode::ASYNC) {
-                tt::tt_metal::EnqueueWriteBuffer(cmd_queue, device_buffer, owned_buffer.get_ptr(), false);
+        const auto buffers = optional_output_tensor->buffers();
+        const bool using_fast_dispatch = (std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr);
+
+        for (auto* buffer : buffers) {
+            if (using_fast_dispatch) {
+                auto& cmd_queue = buffer->device()->command_queue(queue_id);
+                if (CommandQueue::default_mode() == CommandQueue::CommandQueueMode::ASYNC) {
+                    tt::tt_metal::EnqueueWriteBuffer(cmd_queue, *buffer, owned_buffer.get_ptr(), /*blocking=*/false);
+                } else {
+                    tt::tt_metal::EnqueueWriteBuffer(cmd_queue, *buffer, owned_buffer.data(), /*blocking=*/false);
+                }
             } else {
-                tt::tt_metal::EnqueueWriteBuffer(cmd_queue, device_buffer, owned_buffer.data(), false);
+                tt::tt_metal::detail::WriteToBuffer(*buffer, owned_buffer.get());
             }
-        } else {
-            tt::tt_metal::detail::WriteToBuffer(*device_buffer, owned_buffer.get());
         }
-        return optional_output_tensor.value();
+        return *optional_output_tensor;
     }
 }
 
@@ -98,27 +99,26 @@ static Tensor full_impl(
     const tt::tt_metal::LegacyShape& shape,
     const T value,
     const DataType data_type,
-    const Layout layout = Layout::ROW_MAJOR,
-    Device* device = nullptr,
-    const MemoryConfig& output_mem_config =
-        MemoryConfig{.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED},
-    std::optional<Tensor> optional_output_tensor = std::nullopt) {
+    const Layout layout,
+    const std::vector<Device*>& devices,
+    const MemoryConfig& output_mem_config,
+    std::optional<Tensor> optional_output_tensor) {
     switch (data_type) {
         case DataType::UINT8: {
             return detail::full<uint8_t>(
-                queue_id, shape, uint8_t(value), layout, device, output_mem_config, optional_output_tensor);
+                queue_id, shape, uint8_t(value), layout, devices, output_mem_config, optional_output_tensor);
         }
         case DataType::UINT16: {
             return detail::full<uint16_t>(
-                queue_id, shape, uint16_t(value), layout, device, output_mem_config, optional_output_tensor);
+                queue_id, shape, uint16_t(value), layout, devices, output_mem_config, optional_output_tensor);
         }
         case DataType::UINT32: {
            return detail::full<uint32_t>(
-                queue_id, shape, uint32_t(value), layout, device, output_mem_config, optional_output_tensor);
+                queue_id, shape, uint32_t(value), layout, devices, output_mem_config, optional_output_tensor);
         }
         case DataType::FLOAT32: {
             return detail::full<float>(
-                queue_id, shape, float(value), layout, device, output_mem_config, optional_output_tensor);
+                queue_id, shape, float(value), layout, devices, output_mem_config, optional_output_tensor);
         }
         case DataType::BFLOAT16: {
             return detail::full<::bfloat16>(
@@ -126,7 +126,7 @@ static Tensor full_impl(
                 shape,
                 ::bfloat16(static_cast<float>(value)),
                 layout,
-                device,
+                devices,
                 output_mem_config,
                 optional_output_tensor);
         }
@@ -134,6 +134,7 @@ static Tensor full_impl(
     }
 }
 
+// TODO: #14974 - Can this be deleted, as it is only used in tests?
 template <typename T>
 static Tensor full(
     const tt::tt_metal::LegacyShape& shape,
@@ -143,9 +144,18 @@ static Tensor full(
     Device* device = nullptr,
     const MemoryConfig& output_mem_config = MemoryConfig{
         .memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}) {
-    return full_impl(ttnn::DefaultQueueId, shape, value, data_type, layout, device, output_mem_config, std::nullopt);
+    return full_impl(
+        ttnn::DefaultQueueId,
+        shape,
+        value,
+        data_type,
+        layout,
+        device ? std::vector<Device*>{device} : std::vector<Device*>{},
+        output_mem_config,
+        std::nullopt);
 }
 
+// TODO: #14974 - Can this be deleted, as it is only used in tests?
 static Tensor zeros(
     const tt::tt_metal::LegacyShape& shape,
     const DataType data_type = DataType::BFLOAT16,
@@ -156,6 +166,7 @@ static Tensor zeros(
     return full(shape, 0.0f, data_type, layout, device, output_mem_config);
 }
 
+// TODO: #14974 - Can this be deleted, as it is only used in tests?
 static Tensor ones(
     const tt::tt_metal::LegacyShape& shape,
     const DataType data_type = DataType::BFLOAT16,
@@ -166,54 +177,6 @@ static Tensor ones(
     return full(shape, 1.0f, data_type, layout, device, output_mem_config);
 }
 
-template <typename T>
-static Tensor full_like(
-    const Tensor& input_tensor,
-    const T value,
-    std::optional<DataType> data_type = std::nullopt,
-    std::optional<Layout> layout = std::nullopt,
-    std::optional<MemoryConfig> output_mem_config = std::nullopt) {
-    DataType data_type_to_use = input_tensor.get_dtype();
-    if (data_type.has_value()) {
-        data_type_to_use = data_type.value();
-    }
-    Layout layout_to_use = input_tensor.get_layout();
-    if (layout.has_value()) {
-        layout_to_use = layout.value();
-    }
-    if (input_tensor.storage_type() == StorageType::DEVICE) {
-        MemoryConfig output_mem_config_to_use = input_tensor.memory_config();
-        if (output_mem_config.has_value()) {
-            output_mem_config_to_use = output_mem_config.value();
-        }
-        return full(
-            input_tensor.get_legacy_shape(),
-            value,
-            data_type_to_use,
-            layout_to_use,
-            input_tensor.device(),
-            output_mem_config_to_use);
-    } else {
-        return full(input_tensor.get_legacy_shape(), value, data_type_to_use, layout_to_use);
-    }
-}
-
-static Tensor zeros_like(
-    const Tensor& input_tensor,
-    std::optional<DataType> data_type = std::nullopt,
-    std::optional<Layout> layout = std::nullopt,
-    std::optional<MemoryConfig> output_mem_config = std::nullopt) {
-    return full_like(input_tensor, 0.0f, data_type, layout, output_mem_config);
-}
-
-static Tensor ones_like(
-    const Tensor& input_tensor,
-    std::optional<DataType> data_type = std::nullopt,
-    std::optional<Layout> layout = std::nullopt,
-    std::optional<MemoryConfig> output_mem_config = std::nullopt) {
-    return full_like(input_tensor, 1.0f, data_type, layout, output_mem_config);
-}
-
 template <typename T>
 static Tensor arange(
     const int64_t start,
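
The `full_like` / `zeros_like` / `ones_like` helpers deleted from `numpy/functions.hpp` above have registered `ttnn` counterparts, so remaining callers migrate roughly as follows (a sketch; `input` stands for any existing host or device tensor):

```cpp
// Before (removed in this patch): numpy::zeros_like(input), etc.
// After: the registered ttnn operations cover the same cases, and
// full_like/zeros_like/ones_like now also work for MeshDevice tensors.
ttnn::Tensor zeros = ttnn::zeros_like(input);
ttnn::Tensor ones = ttnn::ones_like(input);
ttnn::Tensor filled = ttnn::full_like(input, /*fill_value=*/3.0f);
```
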
diff --git a/ttnn/cpp/ttnn/tensor/tensor.cpp b/ttnn/cpp/ttnn/tensor/tensor.cpp
index 904e1af1dbd..069972451bd 100644
--- a/ttnn/cpp/ttnn/tensor/tensor.cpp
+++ b/ttnn/cpp/ttnn/tensor/tensor.cpp
@@ -840,37 +840,15 @@ void memcpy(Tensor& dst, const Tensor& src, const std::optional<std::size_t> tra
     }
 }
 
-Tensor allocate_tensor_on_device(
+Tensor allocate_tensor_on_devices(
     const ttnn::Shape& shape,
     DataType data_type,
     Layout layout,
-    Device* device,
-    const MemoryConfig& memory_config,
-    const std::optional<Tile>& tile) {
-    // Top level wrapper to asynchronously create a device tensor (single device)
-    Tensor device_tensor = Tensor({device});
-
-    // Save the ref count to later re-set it:
-    // 1. device_tensor is copied in the lambda by the main thread, which increments the ref count.
-    // 2. The destruction happens in a worker thread, which doesn't decrement the ref count.
-    const uint32_t device_tensor_ref_count = device_tensor.tensor_attributes->record_main_thread_ref_count();
-    device->push_work([shape, data_type, layout, device, memory_config, tile, device_tensor]() mutable {
-        auto local_tensor = create_device_tensor(shape, data_type, layout, device, memory_config, tile);
-        device_tensor.populate_buffers_and_metadata(local_tensor);
-    });
-    device_tensor.tensor_attributes->update_main_thread_ref_count(device, device_tensor_ref_count);
-    return device_tensor;
-}
-
-Tensor allocate_tensor_on_device(
-    const ttnn::Shape& shape,
-    DataType data_type,
-    Layout layout,
-    distributed::MeshDevice* mesh_device,
+    const std::vector<Device*>& devices,
     const MemoryConfig& memory_config,
     const std::optional<Tile>& tile) {
-    // Top level wrapper to asynchronously create a device tensor (multi-device)
-    Tensor device_tensor = Tensor(mesh_device->get_devices());
+    // Top level wrapper to asynchronously create a device tensor (single- or multi-device).
+    Tensor device_tensor = Tensor(devices);
     TensorSpec tensor_spec(
         shape.logical_shape(),
         TensorLayout::fromLegacyPaddedShape(data_type, PageConfig(layout, tile), memory_config, shape));
@@ -879,11 +857,11 @@ Tensor allocate_tensor_on_device(
     // 1. device_tensor is copied in the lambda by the main thread, which increments the ref count.
     // 2. The destruction happens in a worker thread, which doesn't decrement the ref count.
     const uint32_t device_tensor_ref_count = device_tensor.tensor_attributes->record_main_thread_ref_count();
-    const auto& workers = device_tensor.get_workers();
-    uint32_t num_workers = workers.size();
+    const auto& workers_in_use = device_tensor.get_workers();
+    uint32_t num_workers = workers_in_use.size();
 
     for (int worker_index = 0; worker_index < num_workers; ++worker_index) {
-        auto& worker = workers[worker_index];
+        auto& worker = devices[worker_index];
         worker->push_work([worker, device_tensor, tensor_spec, worker_index]() mutable {
             auto local_tensor = create_device_tensor(tensor_spec, worker);
             insert_buffer_and_shape_for_device(worker, local_tensor, device_tensor, worker_index);
@@ -894,7 +872,7 @@
             }
         });
     }
-    device_tensor.tensor_attributes->update_main_thread_ref_count(workers.at(0), device_tensor_ref_count);
+    device_tensor.tensor_attributes->update_main_thread_ref_count(workers_in_use.at(0), device_tensor_ref_count);
     return device_tensor;
 }
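
The two `allocate_tensor_on_device` overloads in `tensor.cpp` collapse into the single `allocate_tensor_on_devices` above; callers now pass the worker list explicitly. A usage sketch (the `device` / `mesh_device` handles are illustrative, and the memory config is left at its default):

```cpp
// Single device: pass a one-element vector.
Tensor t1 = allocate_tensor_on_devices(
    ttnn::Shape(std::array<uint32_t, 2>{32, 32}), DataType::BFLOAT16, Layout::TILE, {device});

// MeshDevice: pass all workers; each worker allocates its replica
// asynchronously via push_work, as in the loop above.
Tensor t2 = allocate_tensor_on_devices(
    ttnn::Shape(std::array<uint32_t, 2>{32, 32}), DataType::BFLOAT16, Layout::TILE, mesh_device->get_devices());
```
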
diff --git a/ttnn/cpp/ttnn/tensor/tensor.hpp b/ttnn/cpp/ttnn/tensor/tensor.hpp
index 7abb5c612f3..7a2976ac8f2 100644
--- a/ttnn/cpp/ttnn/tensor/tensor.hpp
+++ b/ttnn/cpp/ttnn/tensor/tensor.hpp
@@ -75,6 +75,7 @@ struct Tensor {
     // Shared pointer to all attributes associated with this tensor
     // Can be safely passed between threads when the tensor is copied
     std::shared_ptr<TensorAttributes> tensor_attributes = nullptr;
+    // Tensor gets worker queue handle through the device
     std::vector<Device*> workers = {};
     bool deallocate_through_destructor = false;
 
@@ -257,7 +258,7 @@ struct Tensor {
         } else if (storage_type == tt::tt_metal::StorageType::MULTI_DEVICE) {
             std::vector<Buffer*> buffers;
             auto storage = std::get<tt::tt_metal::MultiDeviceStorage>(this->get_storage());
-            for (auto buffer : storage.get_buffers()) {
+            for (const auto& buffer : storage.get_buffers()) {
                 buffers.push_back(buffer.get());
             }
             return buffers;
@@ -365,18 +366,11 @@ void memcpy(
 void memcpy(Tensor& dst, const void* src, const std::optional<std::size_t> transfer_size = std::nullopt);
 void memcpy(Tensor& dst, const Tensor& src, const std::optional<std::size_t> transfer_size = std::nullopt);
 
-Tensor allocate_tensor_on_device(
-    const ttnn::Shape& shape,
-    DataType data_type,
-    Layout layout,
-    Device* device,
-    const MemoryConfig& memory_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED},
-    const std::optional<Tile>& tile = std::nullopt);
-Tensor allocate_tensor_on_device(
+Tensor allocate_tensor_on_devices(
     const ttnn::Shape& shape,
     DataType data_type,
     Layout layout,
-    distributed::MeshDevice* mesh_device,
+    const std::vector<Device*>& devices,
     const MemoryConfig& memory_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED},
     const std::optional<Tile>& tile = std::nullopt);
 void write_tensor(const Tensor& host_tensor, Tensor device_tensor, uint8_t cq_id = ttnn::DefaultQueueId);
diff --git a/ttnn/cpp/ttnn/tensor/tensor_utils.cpp b/ttnn/cpp/ttnn/tensor/tensor_utils.cpp
index 9e3e225ea1e..6d11ff06bd6 100644
--- a/ttnn/cpp/ttnn/tensor/tensor_utils.cpp
+++ b/ttnn/cpp/ttnn/tensor/tensor_utils.cpp
@@ -183,7 +183,10 @@ Tensor to_weight_tile_layout(
 // Converts convolution weights to tilized 2d matrix layout.
 // Returns a new tensor with layout=Tile
 Tensor convert_conv_weight_tensor_to_tiled_layout(
-    const Tensor& conv_weight_tensor, uint32_t in1_block_h, uint32_t in1_block_w, std::optional<DataType> output_dtype) {
+    const Tensor& conv_weight_tensor,
+    uint32_t in1_block_h,
+    uint32_t in1_block_w,
+    std::optional<DataType> output_dtype) {
     TT_ASSERT(
         conv_weight_tensor.get_layout() == Layout::ROW_MAJOR &&
         "Convolution weights should be in row major layout for conversion to tilized layout.");
@@ -213,7 +216,10 @@ Tensor convert_conv_weight_tensor_to_tiled_layout(
 // Converts convolution weights to tilized 2d matrix layout.
 // Returns a new tensor with layout=Tile
 Tensor convert_conv_weight_tensor_to_special_padding_tiled_layout(
-    const Tensor& conv_weight_tensor, uint32_t in1_block_h, uint32_t in1_block_w, std::optional<DataType> output_dtype) {
+    const Tensor& conv_weight_tensor,
+    uint32_t in1_block_h,
+    uint32_t in1_block_w,
+    std::optional<DataType> output_dtype) {
     TT_ASSERT(
         conv_weight_tensor.get_layout() == Layout::ROW_MAJOR &&
         "Convolution weights should be in row major layout for conversion to tilized layout.");
@@ -623,8 +629,9 @@ Tensor copy_borrowed_tensor_in_async_mode(Device* worker, const Tensor& tensor)
     // Tensor has workers (on device) or runtime mode is synchronous or tensor has multiple buffers.
     // No need to check for borrowed storage.
     if (worker->get_worker_mode() == WorkExecutorMode::SYNCHRONOUS or
-        tensor.tensor_attributes->num_shards_to_be_populated > 1)
+        tensor.tensor_attributes->num_shards_to_be_populated > 1) {
         return tensor;
+    }
 
     if (tensor.storage_type() == StorageType::BORROWED) {
         ZoneScopedN("CopyBorrowedStorage");