-
Notifications
You must be signed in to change notification settings - Fork 81
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#13745:move tensor.reshape_unsafe to ttnn.experimental
- Loading branch information
Showing
8 changed files
with
256 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
133 changes: 133 additions & 0 deletions
133
ttnn/cpp/ttnn/operations/experimental/reshape/reshape.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
|
||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "ttnn/common/constants.hpp" | ||
#include "ttnn/run_operation.hpp" | ||
#include "reshape.hpp" | ||
#include "tt_metal/common/constants.hpp" | ||
#include <functional> | ||
#include <ttnn/operations/numpy/functions.hpp> | ||
#include "ttnn/operations/experimental/auto_format/auto_format.hpp" | ||
#include "ttnn/tensor/tensor_utils.hpp" | ||
#include "ttnn/operations/data_movement/data_transfer/data_transfer.hpp" | ||
#include "ttnn/operations/data_movement/slice/slice.hpp" | ||
#include "ttnn/operations/core/core.hpp" | ||
|
||
|
||
#include "ttnn/tensor/tensor.hpp" | ||
|
||
#include <cstdint> | ||
#include <memory> | ||
|
||
#include "common/bfloat16.hpp" | ||
#include "ttnn/tensor/tensor_impl.hpp" | ||
#include "ttnn/tensor/tensor_impl_wrapper.hpp" | ||
#include "ttnn/tensor/tensor_utils.hpp" | ||
#include "ttnn/tensor/types.hpp" | ||
#include "tt_metal/common/constants.hpp" | ||
#include "tt_metal/common/math.hpp" | ||
#include "tt_metal/third_party/tracy/public/tracy/Tracy.hpp" | ||
#include "tt_metal/graph/graph_tracking.hpp" | ||
#include "ttnn/distributed/api.hpp" | ||
#include "ttnn/distributed/types.hpp" | ||
#include "ttnn/core.hpp" | ||
|
||
|
||
namespace ttnn{ | ||
|
||
namespace operations::experimental::reshape { | ||
ttnn::Tensor tensor_reshape(const ttnn::Tensor& input_tensor, const ttnn::Shape& new_shape) { | ||
ZoneScoped; | ||
GraphTracker::instance().track_function_start("ttnn::experimental::reshape", input_tensor, new_shape); | ||
const auto& new_padded_shape = new_shape.padded_shape(); | ||
const auto tile = input_tensor.get_tensor_spec().tile(); | ||
TT_ASSERT( | ||
input_tensor.volume() == new_padded_shape.volume(), | ||
"{} != {}", | ||
input_tensor.volume(), | ||
new_padded_shape.volume()); | ||
if (input_tensor.get_layout() == Layout::TILE) { | ||
TT_ASSERT( | ||
new_padded_shape[-2] % tile.get_tile_shape()[0] == 0 && | ||
new_padded_shape[-1] % tile.get_tile_shape()[1] == 0 && | ||
"Expected a multiple of 32 for H, W (or -1 evaluating to such) in ttnn::experimental::reshape()!"); | ||
} | ||
auto output = std::visit( | ||
[&input_tensor, &new_shape, &tile](auto&& storage) -> Tensor { | ||
using T = std::decay_t<decltype(storage)>; | ||
const auto& tensor = input_tensor; | ||
if constexpr (std::is_same_v<T, MultiDeviceHostStorage>) { | ||
auto updated_storage = std::get<T>(tensor.get_storage()); | ||
for (int i = 0; i < updated_storage.shapes.size(); i++) { | ||
updated_storage.shapes[i] = new_shape; | ||
} | ||
return Tensor(updated_storage, new_shape, tensor.get_dtype(), tensor.get_layout(), tile); | ||
} | ||
if constexpr (std::is_same_v<T, MultiDeviceStorage>) { | ||
MultiDeviceStorage updated_storage = std::get<T>(tensor.get_storage()); | ||
std::unordered_map<int, ttnn::Shape> new_shapes; | ||
|
||
for (auto device_id : updated_storage.ordered_device_ids) { | ||
new_shapes.insert({device_id, new_shape}); | ||
} | ||
updated_storage.shapes = new_shapes; | ||
return Tensor(updated_storage, new_shape, tensor.get_dtype(), tensor.get_layout(), tile); | ||
} | ||
if constexpr (std::is_same_v<T, DeviceStorage>) { | ||
if (input_tensor.get_layout() == Layout::ROW_MAJOR) { | ||
if (tensor.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED) { | ||
DeviceStorage device_storage = std::get<T>(tensor.get_storage()); | ||
DeviceBuffer device_buffer = device_storage.get_buffer(); | ||
device_buffer->set_page_size(new_shape[-1] * tensor.element_size()); | ||
device_storage.insert_buffer(device_buffer); | ||
return Tensor(device_storage, new_shape, tensor.get_dtype(), tensor.get_layout(), tile); | ||
} else { | ||
DeviceStorage device_storage = std::get<T>(tensor.get_storage()); | ||
DeviceBuffer device_buffer = device_storage.get_buffer(); | ||
ShardSpecBuffer shard_spec_buffer = device_buffer->shard_spec(); | ||
|
||
auto shard_spec = shard_spec_buffer.tensor_shard_spec; | ||
auto shard_shape = shard_spec.shape; | ||
|
||
uint32_t mul_div = new_shape[-1] > shard_shape[1] ? (new_shape[-1] / shard_shape[1]) | ||
: (shard_shape[1] / new_shape[-1]); | ||
shard_spec.shape[0] = | ||
new_shape[-1] > shard_shape[1] ? shard_shape[0] / mul_div : shard_shape[0] * mul_div; | ||
shard_spec.shape[1] = new_shape[-1]; | ||
|
||
shard_spec_buffer.page_shape = {1, new_shape[-1]}; | ||
shard_spec_buffer.tensor2d_shape = {shard_spec.shape[0], 1}; | ||
shard_spec_buffer.set_shard_spec(shard_spec); | ||
|
||
device_buffer->set_shard_spec(shard_spec_buffer); | ||
device_storage.insert_buffer(device_buffer); | ||
|
||
return Tensor(device_storage, new_shape, tensor.get_dtype(), tensor.get_layout(), tile); | ||
} | ||
} else { | ||
return Tensor(tensor.get_storage(), new_shape, tensor.get_dtype(), tensor.get_layout(), tile); | ||
} | ||
} else { | ||
return Tensor(tensor.get_storage(), new_shape, tensor.get_dtype(), tensor.get_layout(), tile); | ||
} | ||
}, | ||
input_tensor.get_storage()); | ||
output = tt::tt_metal::set_tensor_id(output); | ||
GraphTracker::instance().track_function_end(output); | ||
return output; | ||
} | ||
|
||
|
||
|
||
ttnn::Tensor ReshapeOperation::invoke(const ttnn::Tensor& tensor, const ttnn::SimpleShape& shape) { | ||
return tensor_reshape(tensor, shape); | ||
} | ||
|
||
ttnn::Tensor ReshapeOperation::invoke(const ttnn::Tensor& tensor, const ttnn::Shape& shape) { | ||
return tensor_reshape(tensor, shape); | ||
} | ||
|
||
} // namespace operations::experimental::reshape | ||
} //namespace ttnn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include "ttnn/run_operation.hpp" | ||
#include "ttnn/decorators.hpp" | ||
#include <optional> | ||
|
||
namespace ttnn { | ||
namespace operations::experimental::reshape { | ||
|
||
|
||
struct ReshapeOperation { | ||
static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const ttnn::Shape& shape); | ||
static ttnn::Tensor invoke(const ttnn::Tensor& input_tensor, const ttnn::SimpleShape& shape); | ||
}; | ||
|
||
} // namespace operations::experimental::reshape | ||
|
||
namespace experimental { | ||
constexpr auto reshape = | ||
ttnn::register_operation_with_auto_launch_op<"ttnn::experimental::reshape", ttnn::operations::experimental::reshape::ReshapeOperation>(); | ||
} // namespace experimental | ||
} // namespace ttnn |
72 changes: 72 additions & 0 deletions
72
ttnn/cpp/ttnn/operations/experimental/reshape/reshape_pybind.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#include "reshape_pybind.hpp" | ||
#include "reshape.hpp" | ||
|
||
#include <pybind11/pybind11.h> | ||
#include <pybind11/stl.h> | ||
|
||
#include "ttnn/cpp/pybind11/decorators.hpp" | ||
|
||
#include "ttnn/types.hpp" | ||
|
||
#include "ttnn/tensor/tensor.hpp" | ||
#include "ttnn/tensor/tensor_impl.hpp" | ||
|
||
|
||
namespace ttnn::operations::experimental::reshape::detail { | ||
namespace py = pybind11; | ||
|
||
void py_bind_reshape(py::module& module) { | ||
auto doc = R"doc( | ||
Note: for a 0 cost view, the following conditions must be met: | ||
* the last dimension must not change | ||
* In Tiled the second last two dimensions must not change OR there is no padding on the second last dimension | ||
Args: | ||
* input_tensor: Input Tensor. | ||
* new_shape: New shape of tensor. | ||
Returns: | ||
ttnn.Tensor: the output tensor with the new shape. | ||
Example: | ||
>>> tensor = ttnn.from_torch(torch.tensor((1, 4), dtype=torch.bfloat16), device=device) | ||
>>> output = ttnn.experimental.reshape(tensor, (1, 1, 2, 2)) | ||
)doc"; | ||
bind_registered_operation( | ||
module, | ||
ttnn::experimental::reshape, | ||
doc, | ||
ttnn::pybind_overload_t{ | ||
[](const decltype(ttnn::experimental::reshape)& self, ttnn::Tensor& input_tensor, int N, int C, int H, int W) { | ||
return self(input_tensor, infer_dims_for_reshape(input_tensor, ttnn::SmallVector<int>{N, C, H, W})); | ||
}, | ||
py::arg("input_tensor"), | ||
py::arg("N"), | ||
py::arg("C"), | ||
py::arg("H"), | ||
py::arg("W"), | ||
}, | ||
|
||
ttnn::pybind_overload_t{ | ||
[](const decltype(ttnn::experimental::reshape)& self, ttnn::Tensor& input_tensor, const ttnn::Shape& shape) { | ||
return self(input_tensor, shape); }, | ||
py::arg("input_tensor"), | ||
py::arg("shape"), | ||
}, | ||
ttnn::pybind_overload_t{ | ||
[](const decltype(ttnn::experimental::reshape)& self, ttnn::Tensor& input_tensor, const ttnn::SmallVector<int32_t>& shape) { | ||
return self(input_tensor, infer_dims_for_reshape(input_tensor, shape)); | ||
}, | ||
py::arg("input_tensor"), | ||
py::arg("shape"), | ||
}); | ||
} | ||
|
||
} // namespace ttnn::operations::experimental::reshape::detail |
13 changes: 13 additions & 0 deletions
13
ttnn/cpp/ttnn/operations/experimental/reshape/reshape_pybind.hpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#pragma once | ||
|
||
#include "pybind11/pybind_fwd.hpp" | ||
|
||
namespace ttnn::operations::experimental::reshape::detail { | ||
|
||
void py_bind_reshape(pybind11::module& module); | ||
|
||
} // namespace ttnn::operations::experimental |