#14974: ttnn::{full,empty}_like Tensor creation API for MeshDevice (#15333)

### Ticket
#14974 

### Problem description
Extensions to `Tensor` creation APIs to support `MeshDevice`.

### What's changed
Follow-up to #15191:
* Extend support for `MeshDevice` in `ttnn::empty_like`, `ttnn::full`, and
`ttnn::full_like` (see the usage sketch below).
* Minor refactor of `Tensor` allocation functions.
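
For reference, a minimal usage sketch distilled from the new unit tests in this commit (the helper name is hypothetical; the type aliases `Tensor`, `DataType`, `Layout`, `MemoryConfig`, `MeshDevice` are those used in the test file below):

```cpp
#include <array>
#include <cstdint>
#include <functional>  // std::ref
#include <optional>

// Sketch: create tensors replicated across every device of an already-opened mesh.
void replicate_across_mesh(MeshDevice& mesh_device) {
    // ttnn::full now accepts a MeshDevice target directly.
    const Tensor full_tensor = ttnn::full(
        ttnn::Shape(std::array<uint32_t, 2>{32, 32}),
        /*fill_value=*/42,
        DataType::BFLOAT16,
        Layout::ROW_MAJOR,
        std::ref(mesh_device),
        MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, std::nullopt});

    // ttnn::full_like inherits shape/dtype/layout from the prototype tensor;
    // passing the mesh yields a MULTI_DEVICE tensor with a ReplicateTensor
    // distribution config, as the new tests assert.
    const Tensor replicated_like = ttnn::full_like(
        full_tensor,
        /*fill_value=*/7,
        /*dtype=*/std::nullopt,
        /*layout=*/std::nullopt,
        std::ref(mesh_device));
}
```

Both results report `StorageType::MULTI_DEVICE` and one worker per mesh device.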

### Checklist
- [X] [Post-commit CI passes](https://github.com/tenstorrent/tt-metal/actions/runs/12054314234) (the PR was updated to address the static checks)
- [X] New/Existing tests provide coverage for changes
- [X] [T3K unit tests + frequent tests](https://github.com/tenstorrent/tt-metal/actions/runs/12054324195) (the failures are the same as on main)
omilyutin-tt authored Nov 27, 2024
1 parent d183d61 commit ed6dda9
Showing 14 changed files with 645 additions and 482 deletions.
@@ -3,6 +3,7 @@
//
// SPDX-License-Identifier: Apache-2.0

#include <optional>
#include <variant>

#include "buffers/buffer_constants.hpp"
@@ -24,11 +25,11 @@ using ::tt::tt_metal::TensorMemoryLayout;

class MultiDeviceTensorCreationTest : public T3kMultiDeviceFixture, public ::testing::WithParamInterface<bool> {};

-TEST_P(MultiDeviceTensorCreationTest, CreateEmpty) {
+TEST_P(MultiDeviceTensorCreationTest, Empty) {
MeshDevice* mesh_device = this->mesh_device_.get();
mesh_device->enable_async(GetParam());

-const auto mesh_replicated_tensor = ttnn::empty(
+const Tensor mesh_replicated_tensor = ttnn::empty(
ttnn::Shape(std::array<uint32_t, 2>{32, 32}),
DataType::BFLOAT16,
Layout::ROW_MAJOR,
@@ -39,7 +40,133 @@ TEST_P(MultiDeviceTensorCreationTest, CreateEmpty) {
EXPECT_EQ(mesh_replicated_tensor.get_workers().size(), mesh_device->num_devices());

const auto distributed_tensor_config = get_distributed_tensor_config_from_tensor(mesh_replicated_tensor);
EXPECT_TRUE(std::holds_alternative<ReplicateTensor>(distributed_tensor_config));
}

TEST_P(MultiDeviceTensorCreationTest, EmptyLike) {
MeshDevice* mesh_device = this->mesh_device_.get();
mesh_device->enable_async(GetParam());

ASSERT_FALSE(mesh_device->get_devices().empty());

const Tensor tensor = ttnn::empty(
ttnn::Shape(std::array<uint32_t, 2>{32, 32}),
DataType::BFLOAT16,
Layout::ROW_MAJOR,
mesh_device->get_devices().at(0),
MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, std::nullopt});

EXPECT_EQ(tensor.storage_type(), StorageType::DEVICE);
EXPECT_EQ(tensor.get_workers().size(), 1);

const Tensor mesh_replicated_tensor = ttnn::empty_like(
tensor,
DataType::BFLOAT16,
Layout::ROW_MAJOR,
*mesh_device,
MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, std::nullopt});

EXPECT_EQ(mesh_replicated_tensor.storage_type(), StorageType::MULTI_DEVICE);
EXPECT_EQ(mesh_replicated_tensor.get_workers().size(), mesh_device->num_devices());

const auto distributed_tensor_config = get_distributed_tensor_config_from_tensor(mesh_replicated_tensor);
EXPECT_TRUE(std::holds_alternative<ReplicateTensor>(distributed_tensor_config));
}

TEST_P(MultiDeviceTensorCreationTest, Full) {
MeshDevice* mesh_device = this->mesh_device_.get();
mesh_device->enable_async(GetParam());

const Tensor mesh_replicated_tensor = ttnn::full(
ttnn::Shape(std::array<uint32_t, 2>{32, 32}),
/*fill_value=*/42,
DataType::BFLOAT16,
Layout::ROW_MAJOR,
std::ref(*mesh_device),
MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, std::nullopt});

EXPECT_EQ(mesh_replicated_tensor.storage_type(), StorageType::MULTI_DEVICE);
EXPECT_EQ(mesh_replicated_tensor.get_workers().size(), mesh_device->num_devices());
EXPECT_EQ(mesh_replicated_tensor.shape(), ttnn::SimpleShape({32, 32}));
EXPECT_EQ(mesh_replicated_tensor.dtype(), DataType::BFLOAT16);
EXPECT_EQ(mesh_replicated_tensor.layout(), Layout::ROW_MAJOR);

const auto distributed_tensor_config = get_distributed_tensor_config_from_tensor(mesh_replicated_tensor);
EXPECT_TRUE(std::holds_alternative<ReplicateTensor>(distributed_tensor_config));
}

TEST_P(MultiDeviceTensorCreationTest, FullLike) {
MeshDevice* mesh_device = this->mesh_device_.get();
mesh_device->enable_async(GetParam());

ASSERT_FALSE(mesh_device->get_devices().empty());

Tensor tensor = ttnn::empty(
ttnn::Shape(std::array<uint32_t, 2>{32, 32}),
DataType::BFLOAT16,
Layout::ROW_MAJOR,
mesh_device->get_devices().at(0),
MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, std::nullopt});

EXPECT_EQ(tensor.storage_type(), StorageType::DEVICE);
EXPECT_EQ(tensor.get_workers().size(), 1);

Tensor mesh_replicated_tensor = ttnn::full_like(
tensor,
/*fill_value=*/42,
/*dtype=*/std::nullopt,
/*layout=*/std::nullopt,
std::ref(*mesh_device));

EXPECT_EQ(mesh_replicated_tensor.storage_type(), StorageType::MULTI_DEVICE);
EXPECT_EQ(mesh_replicated_tensor.get_workers().size(), mesh_device->num_devices());
EXPECT_EQ(mesh_replicated_tensor.shape(), tensor.shape());
EXPECT_EQ(mesh_replicated_tensor.dtype(), tensor.dtype());
EXPECT_EQ(mesh_replicated_tensor.layout(), tensor.layout());

const auto distributed_tensor_config = get_distributed_tensor_config_from_tensor(mesh_replicated_tensor);
EXPECT_TRUE(std::holds_alternative<ReplicateTensor>(distributed_tensor_config));
}

TEST_P(MultiDeviceTensorCreationTest, FullLikeWithOptTensor) {
MeshDevice* mesh_device = this->mesh_device_.get();
mesh_device->enable_async(GetParam());

ASSERT_FALSE(mesh_device->get_devices().empty());

Tensor tensor = ttnn::empty(
ttnn::Shape(std::array<uint32_t, 2>{32, 32}),
DataType::BFLOAT16,
Layout::ROW_MAJOR,
mesh_device->get_devices().at(0),
MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, std::nullopt});

EXPECT_EQ(tensor.storage_type(), StorageType::DEVICE);
EXPECT_EQ(tensor.get_workers().size(), 1);

Tensor opt_output = ttnn::empty(
ttnn::Shape(std::array<uint32_t, 2>{32, 32}),
DataType::BFLOAT16,
Layout::ROW_MAJOR,
mesh_device,
MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, std::nullopt});

Tensor mesh_replicated_tensor = ttnn::full_like(
tensor,
/*fill_value=*/42,
/*dtype=*/std::nullopt,
/*layout=*/std::nullopt,
/*device=*/std::nullopt,
/*memory_config=*/std::nullopt,
opt_output);

EXPECT_EQ(mesh_replicated_tensor.storage_type(), StorageType::MULTI_DEVICE);
EXPECT_EQ(mesh_replicated_tensor.get_workers().size(), mesh_device->num_devices());
EXPECT_EQ(mesh_replicated_tensor.shape(), tensor.shape());
EXPECT_EQ(mesh_replicated_tensor.dtype(), tensor.dtype());
EXPECT_EQ(mesh_replicated_tensor.layout(), tensor.layout());

const auto distributed_tensor_config = get_distributed_tensor_config_from_tensor(mesh_replicated_tensor);
EXPECT_TRUE(std::holds_alternative<ReplicateTensor>(distributed_tensor_config));
}

140 changes: 57 additions & 83 deletions tests/ttnn/unit_tests/operations/test_creation.py
@@ -32,62 +32,6 @@ def test_zeros_like(device, input_shape):
assert torch.allclose(torch_output_tensor, output_tensor)


-@pytest.mark.parametrize(
-"input_shape",
-[
-[32, 32],
-[5, 96, 64],
-],
-)
-def test_zeros_like_bf8b(device, input_shape):
-torch_input_tensor = torch.rand((input_shape), dtype=torch.bfloat16)
-torch_output_tensor = torch.zeros_like(torch_input_tensor)
-
-input_tensor = ttnn.from_torch(torch_input_tensor, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device)
-output_tensor = ttnn.zeros_like(input_tensor)
-assert ttnn.is_tensor_storage_on_device(output_tensor)
-output_tensor = ttnn.from_device(output_tensor)
-output_tensor = ttnn.to_torch(output_tensor).to(torch.bfloat16)
-
-assert_with_pcc(torch_output_tensor, output_tensor, 0.9999)
-assert torch.allclose(torch_output_tensor, output_tensor)
-
-
-@pytest.mark.parametrize(
-"input_shape",
-[
-[32, 32],
-[5, 96, 64],
-],
-)
-@pytest.mark.parametrize(
-"layout",
-[ttnn.Layout.ROW_MAJOR, ttnn.Layout.TILE],
-)
-def test_zeros_like_opt(device, layout, input_shape):
-torch_input_tensor = torch.rand((input_shape), dtype=torch.bfloat16)
-torch_output_tensor = torch.zeros_like(torch_input_tensor)
-opt_tensor = torch.ones(input_shape, dtype=torch.bfloat16)
-opt_tensor = ttnn.from_torch(
-opt_tensor, ttnn.bfloat16, layout=layout, device=device, memory_config=ttnn.L1_MEMORY_CONFIG
-)
-
-input_tensor = ttnn.from_torch(torch_input_tensor, layout=layout)
-input_tensor = ttnn.to_device(input_tensor, device)
-
-cq_id = 0
-pages_before = ttnn._ttnn.reports.get_buffer_pages()
-ttnn.zeros_like(input_tensor, optional_tensor=opt_tensor, queue_id=cq_id)
-assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages())
-
-assert ttnn.is_tensor_storage_on_device(opt_tensor)
-opt_tensor = ttnn.from_device(opt_tensor)
-opt_tensor = ttnn.to_torch(opt_tensor)
-
-assert_with_pcc(torch_output_tensor, opt_tensor, 0.9999)
-assert torch.allclose(torch_output_tensor, opt_tensor)


@pytest.mark.parametrize(
"input_shape",
[
@@ -110,35 +54,13 @@ def test_ones_like(device, input_shape):
assert torch.allclose(torch_output_tensor, output_tensor)


-@pytest.mark.parametrize(
-"input_shape",
-[
-[32, 32],
-[5, 96, 64],
-],
-)
-def test_ones_like_bf8b(device, input_shape):
-torch_input_tensor = torch.rand((input_shape), dtype=torch.bfloat16)
-torch_output_tensor = torch.ones_like(torch_input_tensor)
-
-input_tensor = ttnn.from_torch(torch_input_tensor, dtype=ttnn.bfloat8_b, layout=ttnn.TILE_LAYOUT, device=device)
-input_tensor = ttnn.to_device(input_tensor, device)
-output_tensor = ttnn.ones_like(input_tensor)
-assert ttnn.is_tensor_storage_on_device(output_tensor)
-output_tensor = ttnn.from_device(output_tensor)
-output_tensor = ttnn.to_torch(output_tensor).to(torch.bfloat16)
-
-assert_with_pcc(torch_output_tensor, output_tensor, 0.9999)
-assert torch.allclose(torch_output_tensor, output_tensor)


@pytest.mark.parametrize(
"input_shape",
[[32, 32], [5, 96, 64], [1, 2, 64, 64], [1, 2, 4, 64, 64]],
)
@pytest.mark.parametrize(
"fill_value",
-[-5, 3, 15, 25],
+[-5.25, 0, 1.0],
)
def test_full_like(device, input_shape, fill_value):
torch_input_tensor = torch.rand((input_shape), dtype=torch.bfloat16)
@@ -161,7 +83,7 @@ def test_full_like(device, input_shape, fill_value):
)
@pytest.mark.parametrize(
"fill_value",
-[-5, 3, 15, 25],
+[-5.25, 0, 1.0],
)
def test_full_like_bf8b(device, input_shape, fill_value):
torch_input_tensor = torch.rand((input_shape), dtype=torch.bfloat16)
@@ -187,7 +109,7 @@ def test_full_like_bf8b(device, input_shape, fill_value):
)
@pytest.mark.parametrize(
"fill_value",
-[-5, 3, 15, 25],
+[-5.25, 0, 1.0],
)
@pytest.mark.parametrize(
"layout",
@@ -286,6 +208,7 @@ def test_full(device, input_shape, fill_value, layout):
[
[32, 32],
[5, 96, 64],
[1, 50257],
],
)
@pytest.mark.parametrize(
@@ -314,6 +237,34 @@ def test_full_with_opt_tensor(device, input_shape, layout, fill_value):
assert torch.allclose(torch_tensor, opt_tensor)


@pytest.mark.parametrize(
"input_shape",
[
[32, 32],
[5, 96, 64],
[1, 50257],
],
)
@pytest.mark.parametrize(
"fill_value",
[-5.25, 0, 1.0],
)
@pytest.mark.parametrize(
"layout",
[ttnn.Layout.ROW_MAJOR, ttnn.Layout.TILE],
)
def test_full_multi_device(mesh_device, input_shape, fill_value, layout):
torch_tensor = torch.full(input_shape, dtype=torch.bfloat16, fill_value=fill_value)

tensor = ttnn.full(input_shape, device=mesh_device, fill_value=fill_value, layout=layout)
assert ttnn.is_tensor_storage_on_device(tensor)
output_tensors = ttnn.to_torch(tensor, mesh_composer=ttnn.ListMeshToTensor(mesh_device))

for output_tensor in output_tensors:
assert_with_pcc(torch_tensor, output_tensor, 0.9999)
assert torch.allclose(torch_tensor, output_tensor)


@pytest.mark.parametrize(
"start",
[4, 8, 16, 32],
@@ -403,7 +354,6 @@ def test_empty_multi_device(mesh_device, input_shapes):
)
def test_empty_like(device, input_shapes):
torch_input_tensor = torch.ones((input_shapes), dtype=torch.bfloat16)
-torch_output_tensor = torch.empty(torch_input_tensor.shape, dtype=torch.bfloat16)

input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT)
input_tensor = ttnn.to_device(input_tensor, device)
@@ -412,4 +362,28 @@ def test_empty_like(device, input_shapes):
output_tensor = ttnn.from_device(output_tensor)
output_tensor = ttnn.to_torch(output_tensor)

-assert list(torch_output_tensor.shape) == list(output_tensor.shape)
+assert list(torch_input_tensor.shape) == list(output_tensor.shape)


@pytest.mark.parametrize(
"input_shapes",
[
[2, 1, 4, 4], # 256x256
[2, 1280, 8, 8],
[2, 640, 16, 16],
[2, 1280, 8, 8], # 512x512
[2, 1280, 16, 16],
[2, 1280, 16, 16],
],
)
def test_empty_like_multi_device(mesh_device, input_shapes):
torch_input_tensor = torch.empty((input_shapes), dtype=torch.bfloat16)

input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT)
input_tensor = ttnn.to_device(input_tensor, mesh_device)
output_tensor = ttnn.empty_like(input_tensor, layout=ttnn.TILE_LAYOUT)
output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
output_tensor = ttnn.from_device(output_tensor)
output_tensors = ttnn.to_torch(output_tensor, mesh_composer=ttnn.ListMeshToTensor(mesh_device))
for output_tensor in output_tensors:
assert list(torch_input_tensor.shape) == list(output_tensor.shape)
5 changes: 2 additions & 3 deletions tt-train/sources/ttml/core/tt_tensor_utils.cpp
@@ -166,11 +166,10 @@ tt::tt_metal::Tensor full(
(padded[2] + additional_padding_h),
(padded[3] + additional_padding_w),
});
-// temporary solution to avoid using the device, and use only MeshDevice in highlevel api
-return ttnn::full(padded_shape, value, dtype, Layout::TILE, std::ref(*device->get_device(0)));
+return ttnn::full(padded_shape, value, dtype, Layout::TILE, std::ref(*device));
}
// if not padding available, we can just create a tensor with the given shape
-return ttnn::full(shape, value, dtype, Layout::TILE, std::ref(*device->get_device(0)));
+return ttnn::full(shape, value, dtype, Layout::TILE, std::ref(*device));
}

tt::tt_metal::Tensor zeros(const ttnn::Shape& shape, ttnn::distributed::MeshDevice* device, DataType dtype) {
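The net effect of the tt-train change above, as a sketch (`full_on_mesh` is a hypothetical helper; type names follow `tt_tensor_utils.cpp`, where `device` is a `ttnn::distributed::MeshDevice*`):

```cpp
#include <functional>  // std::ref

// Sketch: tt-train no longer needs to unwrap the mesh to a single device.
tt::tt_metal::Tensor full_on_mesh(
    const ttnn::Shape& shape, float value, DataType dtype,
    ttnn::distributed::MeshDevice* device) {
    // Before this commit, ttnn::full could only target a single Device, so the
    // code worked around it with std::ref(*device->get_device(0)).
    // Now the MeshDevice is passed directly and the tensor is replicated on it:
    return ttnn::full(shape, value, dtype, Layout::TILE, std::ref(*device));
}
```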
2 changes: 2 additions & 0 deletions ttnn/CMakeLists.txt
@@ -391,6 +391,7 @@ set(ALL_TTNN_SRCS
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/clone/clone_pybind.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/clone/device/clone_program_factory.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/clone/device/clone_device_operation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/creation.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/sharding_utilities.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/uniform/uniform.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/uniform/uniform_pybind.cpp
@@ -620,6 +621,7 @@ set(TTNN_PUBLIC_LINK_DIRS "")
set(TTNN_PRECOMPILED_HEADERS
${PROJECT_SOURCE_DIR}/tt_metal/tt_stl/reflection.hpp
${PROJECT_SOURCE_DIR}/ttnn/cpp/ttnn/operation.hpp
${PROJECT_SOURCE_DIR}/ttnn/cpp/ttnn/any_device.hpp
${PROJECT_SOURCE_DIR}/tt_metal/third_party/tracy/public/tracy/Tracy.hpp
${PROJECT_SOURCE_DIR}/tt_metal/third_party/umd/device/device_api_metal.h
${PROJECT_SOURCE_DIR}/tt_metal/third_party/umd/device/cluster.h