From 01163e97e7809aee22f49b6dd55ef78427b2bed0 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 29 Nov 2024 00:36:06 +0000 Subject: [PATCH] #13745:replace tensor.reshape with tensor.reshape_unsafe --- ttnn/cpp/pybind11/pytensor.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ttnn/cpp/pybind11/pytensor.cpp b/ttnn/cpp/pybind11/pytensor.cpp index 0a244cc1655..75b3893d557 100644 --- a/ttnn/cpp/pybind11/pytensor.cpp +++ b/ttnn/cpp/pybind11/pytensor.cpp @@ -19,6 +19,9 @@ #include "ttnn/tensor/tensor.hpp" #include "ttnn/tensor/tensor_impl.hpp" +#include "ttnn/common/constants.hpp" +#include "ttnn/operations/core/core.hpp" + using namespace tt::tt_metal; namespace py = pybind11; @@ -1557,7 +1560,7 @@ void pytensor_module(py::module& m_tensor) { dtype = tt_tensor.get_dtype() )doc") .def( - "reshape", + "reshape_unsafe", [](Tensor& self, int N, int C, int H, int W) { return self.reshape(infer_dims_for_reshape(self, ttnn::SmallVector{N, C, H, W})); }, @@ -1569,7 +1572,7 @@ void pytensor_module(py::module& m_tensor) { reshaped_tensor = tt_tensor.reshape(N, C, H, W) )doc") .def( - "reshape", + "reshape_unsafe", [](Tensor& self, const ttnn::Shape& shape) -> Tensor { return self.reshape(shape); }, R"doc( Reshapes TT tensor @@ -1579,7 +1582,7 @@ void pytensor_module(py::module& m_tensor) { reshaped_tensor = tt_tensor.reshape((4, 3, 32)) )doc") .def( - "reshape", + "reshape_unsafe", [](Tensor& self, const ttnn::SmallVector& shape) -> Tensor { return self.reshape(infer_dims_for_reshape(self, shape)); },