From cb68490a01cc7defc700f98802800d622e254c4e Mon Sep 17 00:00:00 2001 From: Artem Yerofieiev <169092593+ayerofieiev-tt@users.noreply.github.com> Date: Thu, 5 Sep 2024 20:25:20 -0700 Subject: [PATCH] Handle Shapes with 0 dimensions in Tensor compute_strides (#12286) --- ttnn/cpp/ttnn/tensor/tensor.cpp | 14 +------------- ttnn/cpp/ttnn/tensor/tensor_utils.hpp | 15 ++++++++++++--- ttnn/cpp/ttnn/tensor/types.hpp | 1 - 3 files changed, 13 insertions(+), 17 deletions(-) diff --git a/ttnn/cpp/ttnn/tensor/tensor.cpp b/ttnn/cpp/ttnn/tensor/tensor.cpp index 6a3866319f1..325f16bd7bc 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor.cpp @@ -494,19 +494,7 @@ StorageType Tensor::storage_type() const { this->get_storage()); } -namespace detail { -const Shape compute_strides(const Shape& shape) { - auto num_elements = compute_volume(shape); - std::vector strides; - for (std::int32_t index = 0; index < shape.rank(); index++) { - num_elements /= shape[index]; - strides.push_back(num_elements); - } - return strides; -} -} // namespace detail - -const Shape Tensor::strides() const { return detail::compute_strides(this->get_legacy_shape()); } +const Shape Tensor::strides() const { return Shape(tt::tt_metal::compute_strides(this->get_legacy_shape())); } uint32_t Tensor::volume() const { return tt::tt_metal::compute_volume(this->get_legacy_shape()); } diff --git a/ttnn/cpp/ttnn/tensor/tensor_utils.hpp b/ttnn/cpp/ttnn/tensor/tensor_utils.hpp index cafbdda083c..711d34022c9 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_utils.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor_utils.hpp @@ -46,17 +46,26 @@ static std::size_t compute_volume(const T& shape) { return volume; } -static std::vector compute_strides(Shape shape) { +static std::vector compute_strides(const Shape& shape) { + if (shape.rank() == 0) + return {}; + auto num_elements = compute_volume(shape); - std::vector strides; + std::vector strides; for (std::int32_t index = 0; index < shape.rank(); index++) { + if (shape[index] == 0) { + // Insert 0 to indicate no memory access for this dimension + strides.push_back(0); + continue; + } + num_elements /= shape[index]; strides.push_back(num_elements); } return strides; } -static int compute_flat_indices(vector indices, vector strides) { +static int compute_flat_indices(const vector& indices, const vector strides) { int flat_index = 0; for (auto i = 0; i < indices.size(); i++) { flat_index += indices[i] * strides[i]; diff --git a/ttnn/cpp/ttnn/tensor/types.hpp b/ttnn/cpp/ttnn/tensor/types.hpp index f30339a80c8..55b5417fc15 100644 --- a/ttnn/cpp/ttnn/tensor/types.hpp +++ b/ttnn/cpp/ttnn/tensor/types.hpp @@ -245,7 +245,6 @@ class Shape { } return ret_array; } - }; inline std::ostream &operator<<(std::ostream &os, const Shape &shape) {