# #14974: ttnn::empty Tensor creation API for MeshDevice (#15191)
### Ticket
#14974 

### Problem description
The `Tensor` creation APIs need to be extended to support `MeshDevice`.

### What's changed
Added an overload of `ttnn::empty` that accepts a `MeshDevice`. The tensor
distribution currently uses a replication strategy: the tensor is allocated on every device in the mesh.

Minor formatting fixes / code comments.
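
For illustration, a minimal sketch of the new overload in use, adapted from the gtest added in this PR (not part of the diff itself; assumes `mesh_device` is in scope and the `tt::tt_metal` type names are visible, as in the new test file):

```cpp
// Sketch: allocate an empty 32x32 tensor, replicated across every device in the mesh.
ttnn::Tensor tensor = ttnn::empty(
    ttnn::Shape(std::array<uint32_t, 2>{32, 32}),
    DataType::BFLOAT16,
    Layout::ROW_MAJOR,
    mesh_device,  // a ttnn::distributed::MeshDevice*, where previously only Device* was accepted
    MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, std::nullopt});
// The result has StorageType::MULTI_DEVICE, with one replica per device in the mesh.
```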

### Checklist
- [x] Post commit CI passes
- [ ] Blackhole Post commit (if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [x] New/Existing tests provide coverage for changes
Authored by omilyutin-tt on Nov 20, 2024 · 1 parent 75d7107 · commit eacb47a
Showing 12 changed files with 146 additions and 47 deletions.

#### tests/ttnn/unit_tests/gtests/CMakeLists.txt (1 addition, 0 deletions)

```diff
@@ -17,6 +17,7 @@ set(TTNN_TENSOR_UNIT_TESTS_SRC
     ${CMAKE_CURRENT_SOURCE_DIR}/tensor/common_tensor_test_utils.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_create_tensor.cpp
     ${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_tensor_layout.cpp
+    ${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_create_tensor_multi_device.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_create_tensor_with_layout.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_shape_base.cpp
    ${CMAKE_CURRENT_SOURCE_DIR}/tensor/test_sharding_with_alignment.cpp
```

#### tests/ttnn/unit_tests/gtests/tensor/test_create_tensor_multi_device.cpp (new file, 49 additions)

```diff
@@ -0,0 +1,49 @@
+// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include <variant>
+
+#include "buffers/buffer_constants.hpp"
+#include "gtest/gtest.h"
+#include "ttnn/cpp/ttnn/operations/creation.hpp"
+#include "ttnn/cpp/ttnn/tensor/types.hpp"
+#include "ttnn/distributed/api.hpp"
+#include "ttnn/tensor/enum_types.hpp"
+#include "ttnn_test_fixtures.hpp"
+
+namespace ttnn::distributed::test {
+namespace {
+
+using ::tt::tt_metal::BufferType;
+using ::tt::tt_metal::Layout;
+using ::tt::tt_metal::MemoryConfig;
+using ::tt::tt_metal::StorageType;
+using ::tt::tt_metal::TensorMemoryLayout;
+
+class MultiDeviceTensorCreationTest : public T3kMultiDeviceFixture, public ::testing::WithParamInterface<bool> {};
+
+TEST_P(MultiDeviceTensorCreationTest, CreateEmpty) {
+    MeshDevice* mesh_device = this->mesh_device_.get();
+    mesh_device->enable_async(GetParam());
+
+    const auto mesh_replicated_tensor = ttnn::empty(
+        ttnn::Shape(std::array<uint32_t, 2>{32, 32}),
+        DataType::BFLOAT16,
+        Layout::ROW_MAJOR,
+        mesh_device,
+        MemoryConfig{TensorMemoryLayout::INTERLEAVED, BufferType::DRAM, std::nullopt});
+
+    EXPECT_EQ(mesh_replicated_tensor.storage_type(), StorageType::MULTI_DEVICE);
+    EXPECT_EQ(mesh_replicated_tensor.get_workers().size(), mesh_device->num_devices());
+
+    const auto distributed_tensor_config = get_distributed_tensor_config_from_tensor(mesh_replicated_tensor);
+
+    EXPECT_TRUE(std::holds_alternative<ReplicateTensor>(distributed_tensor_config));
+}
+
+INSTANTIATE_TEST_SUITE_P(AllTests, MultiDeviceTensorCreationTest, ::testing::Bool());
+
+}  // namespace
+}  // namespace ttnn::distributed::test
```

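For finer-grained verification, the per-device replicas could also be inspected directly; a sketch (not in this diff), assuming `ttnn::distributed::get_device_tensors` from `ttnn/distributed/api.hpp`:

```cpp
// Sketch: each replica of a replicated mesh tensor should report the same
// shape and dtype as the aggregate multi-device tensor.
for (const auto& replica : ttnn::distributed::get_device_tensors(mesh_replicated_tensor)) {
    EXPECT_EQ(replica.get_shape(), mesh_replicated_tensor.get_shape());
    EXPECT_EQ(replica.get_dtype(), mesh_replicated_tensor.get_dtype());
}
```
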
#### tests/ttnn/unit_tests/operations/test_creation.py (24 additions, 4 deletions)

```diff
@@ -347,16 +347,14 @@ def test_arange(device, start, end, step):
         [1, 1, 32, 32],
         [1, 1, 320, 384],
         [1, 3, 320, 384],
+        [1, 3, 180, 64],
         [2, 640, 64, 64],
         [2, 1280, 64, 64],
     ],
 )
 def test_empty(device, input_shapes):
-    torch_input_tensor = torch.ones((input_shapes), dtype=torch.bfloat16)
-    torch_output_tensor = torch.empty(torch_input_tensor.shape, dtype=torch.bfloat16)
+    torch_output_tensor = torch.empty((input_shapes), dtype=torch.bfloat16)

-    input_tensor = ttnn.from_torch(torch_input_tensor, layout=ttnn.TILE_LAYOUT)
-    input_tensor = ttnn.to_device(input_tensor, device)
     output_tensor = ttnn.empty(input_shapes, ttnn.bfloat16, ttnn.TILE_LAYOUT, device, ttnn.DRAM_MEMORY_CONFIG)
     output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
     output_tensor = ttnn.from_device(output_tensor)
@@ -365,6 +363,28 @@ def test_empty(device, input_shapes):
     assert list(torch_output_tensor.shape) == list(output_tensor.shape)


+@pytest.mark.parametrize(
+    "input_shapes",
+    [
+        [1, 1, 32, 32],
+        [1, 1, 320, 384],
+        [1, 3, 320, 384],
+        [1, 3, 180, 64],
+        [2, 640, 64, 64],
+        [2, 1280, 64, 64],
+    ],
+)
+def test_empty_multi_device(mesh_device, input_shapes):
+    torch_output_tensor = torch.empty((input_shapes), dtype=torch.bfloat16)
+
+    output_tensor = ttnn.empty(input_shapes, ttnn.bfloat16, ttnn.TILE_LAYOUT, mesh_device, ttnn.DRAM_MEMORY_CONFIG)
+    output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
+    output_tensor = ttnn.from_device(output_tensor)
+    output_tensors = ttnn.to_torch(output_tensor, mesh_composer=ttnn.ListMeshToTensor(mesh_device))
+    for output_tensor in output_tensors:
+        assert list(torch_output_tensor.shape) == list(output_tensor.shape)
+
+
 @pytest.mark.parametrize(
     "input_shapes",
     [
```

#### tt-train/sources/ttml/core/tt_tensor_utils.cpp (1 addition, 2 deletions)

```diff
@@ -146,8 +146,7 @@ tt::tt_metal::Tensor ones_like(const tt::tt_metal::Tensor& tensor) {

 tt::tt_metal::Tensor empty(
     const ttnn::Shape& shape, ttnn::distributed::MeshDevice* device, const MemoryConfig& memory_config) {
-    // temporary solution to avoid using the device, and use only MeshDevice in highlevel api
-    return ttnn::empty(shape, DataType::BFLOAT16, Layout::TILE, device->get_device(0), memory_config);
+    return ttnn::empty(shape, DataType::BFLOAT16, Layout::TILE, device, memory_config);
 }

 tt::tt_metal::Tensor full(
```

#### ttnn/cpp/pybind11/operations/creation.hpp (15 additions, 1 deletion)

```diff
@@ -337,7 +337,7 @@ void bind_empty_operation(py::module& module, const std::string& info_doc = "")
             shape (List[int]): The shape of the tensor to be created.
             dtype (ttnn.DataType, optional): The tensor data type. Defaults to `ttnn.bfloat16`.
             layout (ttnn.Layout, optional): The tensor layout. Defaults to `ttnn.ROW_MAJOR`.
-            device (ttnn.Device): The device where the tensor will be allocated.
+            device (ttnn.Device | ttnn.MeshDevice): The device where the tensor will be allocated.
             memory_config (ttnn.MemoryConfig, optional): The memory configuration for the operation. Defaults to `ttnn.DRAM_MEMORY_CONFIG`.
        Returns:
@@ -372,6 +372,20 @@
            py::arg("dtype") = DataType::BFLOAT16,
            py::arg("layout") = Layout::ROW_MAJOR,
            py::arg("device"),
-            py::arg("memory_config") = ttnn::DRAM_MEMORY_CONFIG});
+            py::arg("memory_config") = ttnn::DRAM_MEMORY_CONFIG},
+        ttnn::pybind_overload_t{
+            [](const EmptyType& self,
+               const std::vector<uint32_t>& shape,
+               const DataType& dtype,
+               const Layout& layout,
+               MeshDevice* device,
+               const MemoryConfig& memory_config) -> ttnn::Tensor {
+                return self(ttnn::Shape{tt::tt_metal::LegacyShape{shape}}, dtype, layout, device, memory_config);
+            },
+            py::arg("shape"),
+            py::arg("dtype") = DataType::BFLOAT16,
+            py::arg("layout") = Layout::ROW_MAJOR,
+            py::arg("device"),
+            py::arg("memory_config") = ttnn::DRAM_MEMORY_CONFIG});
 }
```

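Both bindings share the Python name `ttnn.empty`; dispatch between them relies on standard pybind11 behavior, where registered overloads are tried in order until the arguments convert. A standalone sketch of that mechanism (generic pybind11, not ttnn code):

```cpp
#include <pybind11/pybind11.h>
namespace py = pybind11;

int dispatch(int x) { return x; }        // analogous to the Device* overload
double dispatch(double x) { return x; }  // analogous to the MeshDevice* overload

PYBIND11_MODULE(example, m) {
    // Overloads are tried in registration order; the first one whose argument
    // types match the Python call wins.
    m.def("dispatch", py::overload_cast<int>(&dispatch));
    m.def("dispatch", py::overload_cast<double>(&dispatch));
}
```
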
#### ttnn/cpp/ttnn/operations/core/core.cpp (3 additions, 6 deletions)

```diff
@@ -71,15 +71,12 @@ ttnn::Tensor to_device(
     const ttnn::Tensor& tensor, MeshDevice* mesh_device, const std::optional<MemoryConfig>& memory_config) {
     auto mem_config = memory_config.value_or(ttnn::DRAM_MEMORY_CONFIG);
     // Currently no direct sharded write support in BLACKHOLE due to alignment issue
-    if(mem_config.is_sharded () and (mesh_device->arch() == tt::ARCH::BLACKHOLE)) {
-       auto interleaved_tensor = tensor.to(mesh_device, ttnn::DRAM_MEMORY_CONFIG);
+    if (mem_config.is_sharded() and (mesh_device->arch() == tt::ARCH::BLACKHOLE)) {
+        auto interleaved_tensor = tensor.to(mesh_device, ttnn::DRAM_MEMORY_CONFIG);
         return ttnn::interleaved_to_sharded(ttnn::DefaultQueueId, interleaved_tensor, mem_config, std::nullopt);
-    }
-    else {
+    } else {
         return tensor.to(mesh_device, mem_config);
     }
-
-
 }

 ttnn::Tensor allocate_tensor_on_device(
```

#### ttnn/cpp/ttnn/operations/creation.hpp (16 additions, 7 deletions)

```diff
@@ -218,13 +218,22 @@ inline constexpr ZerosLike zeros_like{};
 inline constexpr OnesLike ones_like{};

 struct Empty {
-   static ttnn::Tensor invoke(
-       const ttnn::Shape& shape,
-       const DataType& dtype,
-       const Layout& layout,
-       Device* device,
-       const MemoryConfig& memory_config) {
-       return create_device_tensor(shape, dtype, layout, device, memory_config);
+    static ttnn::Tensor invoke(
+        const ttnn::Shape& shape,
+        const DataType& dtype,
+        const Layout& layout,
+        Device* device,
+        const MemoryConfig& memory_config) {
+        return allocate_tensor_on_device(shape, dtype, layout, device, memory_config);
+    }
+
+    static ttnn::Tensor invoke(
+        const ttnn::Shape& shape,
+        const DataType& dtype,
+        const Layout& layout,
+        distributed::MeshDevice* device,
+        const MemoryConfig& memory_config) {
+        return allocate_tensor_on_device(shape, dtype, layout, device, memory_config);
     }
 };

```

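Note that both `invoke` overloads now route through `allocate_tensor_on_device` rather than `create_device_tensor`, so the allocation is pushed onto the device worker thread(s) in async mode (see the `tensor.cpp` changes below). A rough equivalence, assuming `allocate_tensor_on_device` in the `tt::tt_metal` namespace per `ttnn/cpp/ttnn/tensor/tensor.hpp`:

```cpp
// Sketch: the two calls below should produce equivalent MULTI_DEVICE tensors;
// ttnn::empty is now a thin wrapper over the async allocation entry point.
ttnn::Tensor a = ttnn::empty(shape, dtype, layout, mesh_device, memory_config);
ttnn::Tensor b = tt::tt_metal::allocate_tensor_on_device(shape, dtype, layout, mesh_device, memory_config);
```
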
#### ttnn/cpp/ttnn/operations/numpy/functions.hpp (2 additions, 3 deletions)

```diff
@@ -61,14 +61,13 @@ static Tensor full(
     auto owned_buffer = tt::tt_metal::owned_buffer::create<T>(tt::tt_metal::compute_volume(shape));
     std::fill(std::begin(owned_buffer), std::end(owned_buffer), value);

-    if(!optional_output_tensor.has_value()){
+    if (!optional_output_tensor.has_value()){
         auto output = Tensor(OwnedStorage{owned_buffer}, shape, data_type, layout);
         if (device != nullptr) {
             output = output.to(device, output_mem_config);
         }
         return output;
-    }
-    else {
+    } else {
         auto device_buffer = std::get<DeviceStorage>(optional_output_tensor.value().tensor_attributes->storage).get_buffer();
         bool using_fast_dispatch = (std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr);

```

#### ttnn/cpp/ttnn/tensor/tensor.cpp (31 additions, 19 deletions)

```diff
@@ -409,6 +409,7 @@ void Tensor::populate_buffers_and_metadata(const Tensor& other) {
     this->set_dtype(other.get_dtype());
     this->set_layout(other.get_layout());
     this->set_tile(other.get_tile());
+
     // Populate storage container with buffers + shapes
     std::visit(
         [this](auto&& storage) {
@@ -800,10 +801,19 @@
 }

 Tensor allocate_tensor_on_device(
-    const ttnn::Shape& shape, DataType data_type, Layout layout, Device* device, const MemoryConfig& memory_config, const std::optional<Tile>& tile) {
+    const ttnn::Shape& shape,
+    DataType data_type,
+    Layout layout,
+    Device* device,
+    const MemoryConfig& memory_config,
+    const std::optional<Tile>& tile) {
     // Top level wrapper to asynchronously create a device tensor (single device)
     Tensor device_tensor = Tensor({device});
-    uint32_t device_tensor_ref_count = device_tensor.tensor_attributes->record_main_thread_ref_count();
+
+    // Save the ref count to later re-set it:
+    // 1. device_tensor is copied in the lambda by the main thread, which increments the ref count.
+    // 2. The destruction happens in a worker thread, which doesn't decrement the ref count.
+    const uint32_t device_tensor_ref_count = device_tensor.tensor_attributes->record_main_thread_ref_count();
     device->push_work([shape, data_type, layout, device, memory_config, tile, device_tensor]() mutable {
         auto local_tensor = create_device_tensor(shape, data_type, layout, device, memory_config, tile);
         device_tensor.populate_buffers_and_metadata(local_tensor);
@@ -818,31 +828,33 @@ Tensor allocate_tensor_on_device(
     Layout layout,
     distributed::MeshDevice* mesh_device,
     const MemoryConfig& memory_config,
-    const std::optional<Tile>& tile
-) {
+    const std::optional<Tile>& tile) {
     // Top level wrapper to asynchronously create a device tensor (multi-device)
     Tensor device_tensor = Tensor(mesh_device->get_devices());
-    uint32_t device_tensor_ref_count = device_tensor.tensor_attributes->record_main_thread_ref_count();
+
+    // Save the ref count to later re-set it:
+    // 1. device_tensor is copied in the lambda by the main thread, which increments the ref count.
+    // 2. The destruction happens in a worker thread, which doesn't decrement the ref count.
+    const uint32_t device_tensor_ref_count = device_tensor.tensor_attributes->record_main_thread_ref_count();
     const auto& workers = device_tensor.get_workers();
     uint32_t num_workers = workers.size();

     for (int worker_index = 0; worker_index < num_workers; ++worker_index) {
         auto& worker = workers[worker_index];
-        worker->push_work([shape, data_type, layout, worker, memory_config, tile, device_tensor, worker_index]() mutable {
-            auto local_tensor = create_device_tensor(shape, data_type, layout, worker, memory_config, tile);
-            insert_buffer_and_shape_for_device(worker, local_tensor, device_tensor, worker_index);
-
-            uint32_t num_workers_completed = (device_tensor.tensor_attributes->num_workers_completed)++;
-            if (not num_workers_completed) {
-                device_tensor.set_shape(ttnn::Shape(shape));
-                device_tensor.set_dtype(data_type);
-                device_tensor.set_layout(layout);
-                if (tile.has_value()) {
-                    device_tensor.set_tile(tile.value());
+        worker->push_work(
+            [shape, data_type, layout, worker, memory_config, tile, device_tensor, worker_index]() mutable {
+                auto local_tensor = create_device_tensor(shape, data_type, layout, worker, memory_config, tile);
+                insert_buffer_and_shape_for_device(worker, local_tensor, device_tensor, worker_index);
+
+                uint32_t num_workers_completed = (device_tensor.tensor_attributes->num_workers_completed)++;
+                if (not num_workers_completed) {
+                    device_tensor.set_shape(local_tensor.get_shape());
+                    device_tensor.set_dtype(local_tensor.get_dtype());
+                    device_tensor.set_layout(local_tensor.get_layout());
+                    device_tensor.set_tile(local_tensor.get_tile());
+                    device_tensor.tensor_attributes->metadata_populated = true;
                 }
-                device_tensor.tensor_attributes->metadata_populated = true;
-            }
-        });
+            });
     }
     device_tensor.tensor_attributes->update_main_thread_ref_count(workers.at(0), device_tensor_ref_count);
     return device_tensor;
```

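The new comments capture a subtlety of the async path: the tensor handle is captured by value into work pushed onto worker threads, and the copy's destruction there does not decrement the main-thread ref count. A schematic of the pattern, simplified from the single-device path above (not the literal implementation):

```cpp
// Schematic of the ref-count bookkeeping around push_work:
Tensor device_tensor = Tensor({device});
// Snapshot the count before the lambda copy bumps it on the main thread.
const uint32_t saved_count = device_tensor.tensor_attributes->record_main_thread_ref_count();
device->push_work([device_tensor]() mutable {
    // ... allocate and populate on the worker; this copy is destroyed on the
    // worker thread, which does not decrement the main-thread count.
});
// Restore the snapshot so the main-thread count ends up balanced.
device_tensor.tensor_attributes->update_main_thread_ref_count(device, saved_count);
```
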
#### ttnn/cpp/ttnn/tensor/tensor.hpp (0 additions, 1 deletion)

```diff
@@ -267,7 +267,6 @@ struct Tensor {
             TT_THROW("Cannot get the device from a tensor without an allocated buffer");
             return buffer->device();
         } else if (this->storage_type() == tt::tt_metal::StorageType::MULTI_DEVICE) {
-            auto &storage = std::get<MultiDeviceStorage>(this->get_storage());
             return this->get_workers().at(0);
         } else {
             TT_THROW("Cannot get the device from a tensor with host storage");
```

#### ttnn/cpp/ttnn/tensor/tensor_utils.cpp (1 addition, 1 deletion)

```diff
@@ -593,7 +593,7 @@ void insert_buffer_and_shape_for_device(
     Device* target_device, const Tensor& shard, Tensor& tensor_to_modify, std::optional<int> buffer_index) {
     ZoneScopedN("InsertBufferAndShapeForDevice");
     std::visit(
-        [target_device, &shard, &tensor_to_modify, buffer_index](auto&& s) {
+        [target_device, &shard, buffer_index](auto&& s) {
             using T = std::decay_t<decltype(s)>;
             if constexpr (std::is_same_v<T, MultiDeviceHostStorage>) {
                 s.insert_buffer_and_shape_for_device(
```

#### ttnn/cpp/ttnn/tensor/tensor_utils.hpp (3 additions, 3 deletions)

```diff
@@ -106,14 +106,14 @@ bool is_arch_whb0(const tt::ARCH& arch);
 bool is_cpu_tensor(const Tensor& tensor);
 bool is_device_tensor(const Tensor& tensor);

-// Given a multi-device tensor, and a function that transforms a tensor, apply the function to all per-device
+// Given a multi-device tensor, and a function that transforms a tensor, applies the function to all per-device
 // tensors.
 Tensor transform(const Tensor& tensor, std::function<Tensor(const Tensor&)> transform_func);

-// Given a multi-device tensor, and a callable, apply the function to all per-device tensors.
+// Given a multi-device tensor, and a callable, applies the function to all per-device tensors.
 void apply(const Tensor& tensor, std::function<void(const Tensor&)> callable);

-// Given a multi-device tensor, return all the devices it is mapped to.
+// Given a multi-device tensor, returns all the devices it is mapped to.
 std::vector<Device*> get_devices(const Tensor& multi_device_tensor);

 uint32_t num_buffers_in_tensor(const Tensor& tensor);
```
